From 6bdedff45714b42f5bca6e4959b4771ad031fa9b Mon Sep 17 00:00:00 2001 From: IamCathal Date: Thu, 20 Aug 2020 19:00:35 +0100 Subject: [PATCH 01/35] LogFormatter Implementation --- pkg/logs/conn.go | 7 ++++++- pkg/logs/console.go | 18 +++++++++++++++++- pkg/logs/es/es.go | 4 ++++ pkg/logs/file.go | 4 ++++ pkg/logs/jianliao.go | 4 ++++ pkg/logs/log.go | 12 ++++++++++++ pkg/logs/logger.go | 6 +++--- pkg/logs/multifile.go | 4 ++++ pkg/logs/slack.go | 4 ++++ pkg/logs/smtp.go | 4 ++++ 10 files changed, 62 insertions(+), 5 deletions(-) diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go index e0560fd9..79ab410c 100644 --- a/pkg/logs/conn.go +++ b/pkg/logs/conn.go @@ -39,6 +39,10 @@ func NewConn() Logger { return conn } +func (c *connWriter) Format(lm *LogMsg) string { + return lm.Msg +} + // Init initializes a connection writer with json config. // json config only needs they "level" key func (c *connWriter) Init(jsonConfig string) error { @@ -62,7 +66,8 @@ func (c *connWriter) WriteMsg(lm *LogMsg) error { defer c.innerWriter.Close() } - _, err := c.lg.writeln(lm) + msg := c.Format(lm) + _, err := c.lg.writeln(msg) if err != nil { return err } diff --git a/pkg/logs/console.go b/pkg/logs/console.go index 024152aa..86db6178 100644 --- a/pkg/logs/console.go +++ b/pkg/logs/console.go @@ -52,6 +52,20 @@ type consoleWriter struct { Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color } +func (c *consoleWriter) Format(lm *LogMsg) string { + msg := lm.Msg + + if c.Colorful { + msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) + } + + h, _, _ := formatTimeHeader(lm.When) + bytes := append(append(h, msg...), '\n') + + return "eee" + string(bytes) + +} + // NewConsole creates ConsoleWriter returning as LoggerInterface. func NewConsole() Logger { cw := &consoleWriter{ @@ -76,10 +90,12 @@ func (c *consoleWriter) WriteMsg(lm *LogMsg) error { if lm.Level > c.Level { return nil } + // fmt.Printf("Formatted: %s\n\n", c.fmtter.Format(lm)) if c.Colorful { lm.Msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) } - c.lg.writeln(lm) + msg := c.Format(lm) + c.lg.writeln(msg) return nil } diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go index 5c91b2ed..b70e5cf3 100644 --- a/pkg/logs/es/es.go +++ b/pkg/logs/es/es.go @@ -35,6 +35,10 @@ type esLogger struct { Level int `json:"level"` } +func (el *esLogger) Format(lm *logs.LogMsg) string { + return lm.Msg +} + // {"dsn":"http://localhost:9200/","level":1} func (el *esLogger) Init(jsonconfig string) error { err := json.Unmarshal([]byte(jsonconfig), el) diff --git a/pkg/logs/file.go b/pkg/logs/file.go index 23ea4b09..6b33ebb1 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -89,6 +89,10 @@ func newFileWriter() Logger { return w } +func (w *fileLogWriter) Format(lm *LogMsg) string { + return lm.Msg +} + // Init file logger with json config. // jsonConfig like: // { diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go index 0e7cfab4..a108342c 100644 --- a/pkg/logs/jianliao.go +++ b/pkg/logs/jianliao.go @@ -27,6 +27,10 @@ func (s *JLWriter) Init(jsonconfig string) error { return json.Unmarshal([]byte(jsonconfig), s) } +func (s *JLWriter) Format(lm *LogMsg) string { + return lm.Msg +} + // WriteMsg writes message in smtp writer. // Sends an email with subject and only this message. func (s *JLWriter) WriteMsg(lm *LogMsg) error { diff --git a/pkg/logs/log.go b/pkg/logs/log.go index 37421625..d47173e5 100644 --- a/pkg/logs/log.go +++ b/pkg/logs/log.go @@ -86,6 +86,7 @@ type newLoggerFunc func() Logger type Logger interface { Init(config string) error WriteMsg(lm *LogMsg) error + Format(lm *LogMsg) string Destroy() Flush() } @@ -128,6 +129,8 @@ const defaultAsyncMsgLen = 1e3 type nameLogger struct { Logger + // Formatter func(*LogMsg) string + LogFormatter name string } @@ -139,6 +142,10 @@ type LogMsg struct { LineNumber int } +type LogFormatter interface { + Format(lm *LogMsg) string +} + var logMsgPool *sync.Pool // NewLogger returns a new BeeLogger. @@ -179,6 +186,10 @@ func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { return bl } +func Format(lm *LogMsg) string { + return lm.Msg +} + // SetLogger provides a given logger adapter into BeeLogger with config string. // config must in in JSON format like {"interval":360}} func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { @@ -237,6 +248,7 @@ func (bl *BeeLogger) DelLogger(adapterName string) error { func (bl *BeeLogger) writeToLoggers(lm *LogMsg) { for _, l := range bl.outputs { + // fmt.Println("Formatted: ", l.Format(lm)) err := l.WriteMsg(lm) if err != nil { fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err) diff --git a/pkg/logs/logger.go b/pkg/logs/logger.go index 721c8dc1..d8b334d4 100644 --- a/pkg/logs/logger.go +++ b/pkg/logs/logger.go @@ -30,10 +30,10 @@ func newLogWriter(wr io.Writer) *logWriter { return &logWriter{writer: wr} } -func (lg *logWriter) writeln(lm *LogMsg) (int, error) { +func (lg *logWriter) writeln(msg string) (int, error) { lg.Lock() - h, _, _ := formatTimeHeader(lm.When) - n, err := lg.writer.Write(append(append(h, lm.Msg...), '\n')) + msg += "\n" + n, err := lg.writer.Write([]byte(msg)) lg.Unlock() return n, err } diff --git a/pkg/logs/multifile.go b/pkg/logs/multifile.go index 1cd9e9f8..0650c99d 100644 --- a/pkg/logs/multifile.go +++ b/pkg/logs/multifile.go @@ -78,6 +78,10 @@ func (f *multiFileLogWriter) Init(config string) error { return nil } +func (f *multiFileLogWriter) Format(lm *LogMsg) string { + return lm.Msg +} + func (f *multiFileLogWriter) Destroy() { for i := 0; i < len(f.writers); i++ { if f.writers[i] != nil { diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go index dad4f4ea..c31f9330 100644 --- a/pkg/logs/slack.go +++ b/pkg/logs/slack.go @@ -18,6 +18,10 @@ func newSLACKWriter() Logger { return &SLACKWriter{Level: LevelTrace} } +func (s *SLACKWriter) Format(lm *LogMsg) string { + return lm.Msg +} + // Init SLACKWriter with json config string func (s *SLACKWriter) Init(jsonconfig string) error { return json.Unmarshal([]byte(jsonconfig), s) diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go index 0d2b3c29..beadb0d7 100644 --- a/pkg/logs/smtp.go +++ b/pkg/logs/smtp.go @@ -114,6 +114,10 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd return client.Quit() } +func (s *SMTPWriter) Format(lm *LogMsg) string { + return lm.Msg +} + // WriteMsg writes message in smtp writer. // Sends an email with subject and only this message. func (s *SMTPWriter) WriteMsg(lm *LogMsg) error { From 705e091593a49a09904c75896aec1f85aa3c8862 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Thu, 20 Aug 2020 19:06:51 +0100 Subject: [PATCH 02/35] Add format call before logging --- pkg/logs/es/es.go | 2 +- pkg/logs/file.go | 7 ++++--- pkg/logs/jianliao.go | 3 +-- pkg/logs/slack.go | 4 ++-- pkg/logs/smtp.go | 4 +++- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go index b70e5cf3..06dfece1 100644 --- a/pkg/logs/es/es.go +++ b/pkg/logs/es/es.go @@ -71,7 +71,7 @@ func (el *esLogger) WriteMsg(lm *logs.LogMsg) error { idx := LogDocument{ Timestamp: lm.When.Format(time.RFC3339), - Msg: lm.Msg, + Msg: el.Format(lm), } body, err := json.Marshal(idx) diff --git a/pkg/logs/file.go b/pkg/logs/file.go index 6b33ebb1..366fbcf2 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -153,7 +153,8 @@ func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { return nil } hd, d, h := formatTimeHeader(lm.When) - lm.Msg = string(hd) + lm.Msg + "\n" + msg := w.Format(lm) + msg = fmt.Sprintf("%s %s\n", string(hd), msg) if w.Rotate { w.RLock() if w.needRotateHourly(len(lm.Msg), h) { @@ -180,10 +181,10 @@ func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { } w.Lock() - _, err := w.fileWriter.Write([]byte(lm.Msg)) + _, err := w.fileWriter.Write([]byte(msg)) if err == nil { w.maxLinesCurLines++ - w.maxSizeCurSize += len(lm.Msg) + w.maxSizeCurSize += len(msg) } w.Unlock() return err diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go index a108342c..6830bade 100644 --- a/pkg/logs/jianliao.go +++ b/pkg/logs/jianliao.go @@ -38,8 +38,7 @@ func (s *JLWriter) WriteMsg(lm *LogMsg) error { return nil } - text := fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), lm.Msg) - + text := fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.Format(lm)) form := url.Values{} form.Add("authorName", s.AuthorName) form.Add("title", s.Title) diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go index c31f9330..c0584f72 100644 --- a/pkg/logs/slack.go +++ b/pkg/logs/slack.go @@ -33,8 +33,8 @@ func (s *SLACKWriter) WriteMsg(lm *LogMsg) error { if lm.Level > s.Level { return nil } - - text := fmt.Sprintf("{\"text\": \"%s %s\"}", lm.When.Format("2006-01-02 15:04:05"), lm.Msg) + msg := s.Format(lm) + text := fmt.Sprintf("{\"text\": \"%s %s\"}", lm.When.Format("2006-01-02 15:04:05"), msg) form := url.Values{} form.Add("payload", text) diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go index beadb0d7..d992b279 100644 --- a/pkg/logs/smtp.go +++ b/pkg/logs/smtp.go @@ -130,11 +130,13 @@ func (s *SMTPWriter) WriteMsg(lm *LogMsg) error { // Set up authentication information. auth := s.getSMTPAuth(hp[0]) + msg := s.Format(lm) + // Connect to the server, authenticate, set the sender and recipient, // and send the email all in one step. contentType := "Content-Type: text/plain" + "; charset=UTF-8" mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress + - ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", lm.When.Format("2006-01-02 15:04:05")) + lm.Msg) + ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", lm.When.Format("2006-01-02 15:04:05")) + msg) return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg) } From e1da804b2ba54572dc44c76b963d63068a92bad8 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Thu, 20 Aug 2020 19:15:27 +0100 Subject: [PATCH 03/35] Add format func to alils --- pkg/logs/alils/alils.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go index 6c1464f2..2c83e4ee 100644 --- a/pkg/logs/alils/alils.go +++ b/pkg/logs/alils/alils.go @@ -100,6 +100,10 @@ func (c *aliLSWriter) Init(jsonConfig string) (err error) { return nil } +func (c *aliLSWriter) Format(lm *logs.LogMsg) string { + return lm.Msg +} + // WriteMsg writes a message in connection. // If connection is down, try to re-connect. func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { From 08e49ca3233350f56cce2b57e7bc761c5a9ea277 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Thu, 20 Aug 2020 19:32:42 +0100 Subject: [PATCH 04/35] Test empty commit From ed1d2c7f6e2d8589daf69aedf9a9d6e7c5d76d86 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Mon, 24 Aug 2020 20:22:38 +0100 Subject: [PATCH 05/35] Add custom logging format functionality and global formatter functionality --- pkg/logs/conn.go | 25 +++++++++---- pkg/logs/console.go | 30 ++++++++++++--- pkg/logs/file.go | 20 +++++++++- pkg/logs/jianliao.go | 23 ++++++++---- pkg/logs/log.go | 85 ++++++++++++++++++++++++++++++++++++++++--- pkg/logs/multifile.go | 18 ++++++--- pkg/logs/slack.go | 15 ++++++-- pkg/logs/smtp.go | 11 +++++- 8 files changed, 190 insertions(+), 37 deletions(-) diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go index 79ab410c..e11909a0 100644 --- a/pkg/logs/conn.go +++ b/pkg/logs/conn.go @@ -23,13 +23,15 @@ import ( // connWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type connWriter struct { - lg *logWriter - innerWriter io.WriteCloser - ReconnectOnMsg bool `json:"reconnectOnMsg"` - Reconnect bool `json:"reconnect"` - Net string `json:"net"` - Addr string `json:"addr"` - Level int `json:"level"` + lg *logWriter + innerWriter io.WriteCloser + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string + ReconnectOnMsg bool `json:"reconnectOnMsg"` + Reconnect bool `json:"reconnect"` + Net string `json:"net"` + Addr string `json:"addr"` + Level int `json:"level"` } // NewConn creates new ConnWrite returning as LoggerInterface. @@ -45,7 +47,14 @@ func (c *connWriter) Format(lm *LogMsg) string { // Init initializes a connection writer with json config. // json config only needs they "level" key -func (c *connWriter) Init(jsonConfig string) error { +func (c *connWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + c.UseCustomFormatter = true + c.CustomFormatter = elem + } + } + return json.Unmarshal([]byte(jsonConfig), c) } diff --git a/pkg/logs/console.go b/pkg/logs/console.go index 86db6178..a928de7d 100644 --- a/pkg/logs/console.go +++ b/pkg/logs/console.go @@ -47,9 +47,11 @@ var colors = []brush{ // consoleWriter implements LoggerInterface and writes messages to terminal. type consoleWriter struct { - lg *logWriter - Level int `json:"level"` - Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color + lg *logWriter + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string + Level int `json:"level"` + Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color } func (c *consoleWriter) Format(lm *LogMsg) string { @@ -62,7 +64,7 @@ func (c *consoleWriter) Format(lm *LogMsg) string { h, _, _ := formatTimeHeader(lm.When) bytes := append(append(h, msg...), '\n') - return "eee" + string(bytes) + return string(bytes) } @@ -78,10 +80,18 @@ func NewConsole() Logger { // Init initianlizes the console logger. // jsonConfig must be in the format '{"level":LevelTrace}' -func (c *consoleWriter) Init(jsonConfig string) error { +func (c *consoleWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + c.UseCustomFormatter = true + c.CustomFormatter = elem + } + } + if len(jsonConfig) == 0 { return nil } + return json.Unmarshal([]byte(jsonConfig), c) } @@ -94,7 +104,15 @@ func (c *consoleWriter) WriteMsg(lm *LogMsg) error { if c.Colorful { lm.Msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) } - msg := c.Format(lm) + + msg := "" + + if c.UseCustomFormatter { + msg = c.CustomFormatter(lm) + } else { + msg = c.Format(lm) + } + c.lg.writeln(msg) return nil } diff --git a/pkg/logs/file.go b/pkg/logs/file.go index 366fbcf2..4576e19d 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -60,6 +60,9 @@ type fileLogWriter struct { hourlyOpenDate int hourlyOpenTime time.Time + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string + Rotate bool `json:"rotate"` Level int `json:"level"` @@ -104,7 +107,14 @@ func (w *fileLogWriter) Format(lm *LogMsg) string { // "rotate":true, // "perm":"0600" // } -func (w *fileLogWriter) Init(jsonConfig string) error { +func (w *fileLogWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + w.UseCustomFormatter = true + w.CustomFormatter = elem + } + } + err := json.Unmarshal([]byte(jsonConfig), w) if err != nil { return err @@ -153,7 +163,13 @@ func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { return nil } hd, d, h := formatTimeHeader(lm.When) - msg := w.Format(lm) + msg := "" + if w.UseCustomFormatter { + msg = w.CustomFormatter(lm) + } else { + msg = w.Format(lm) + } + msg = fmt.Sprintf("%s %s\n", string(hd), msg) if w.Rotate { w.RLock() diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go index 6830bade..9877bed6 100644 --- a/pkg/logs/jianliao.go +++ b/pkg/logs/jianliao.go @@ -9,12 +9,14 @@ import ( // JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook type JLWriter struct { - AuthorName string `json:"authorname"` - Title string `json:"title"` - WebhookURL string `json:"webhookurl"` - RedirectURL string `json:"redirecturl,omitempty"` - ImageURL string `json:"imageurl,omitempty"` - Level int `json:"level"` + AuthorName string `json:"authorname"` + Title string `json:"title"` + WebhookURL string `json:"webhookurl"` + RedirectURL string `json:"redirecturl,omitempty"` + ImageURL string `json:"imageurl,omitempty"` + Level int `json:"level"` + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string } // newJLWriter creates jiaoliao writer. @@ -23,7 +25,14 @@ func newJLWriter() Logger { } // Init JLWriter with json config string -func (s *JLWriter) Init(jsonconfig string) error { +func (s *JLWriter) Init(jsonconfig string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + s.UseCustomFormatter = true + s.CustomFormatter = elem + } + } + return json.Unmarshal([]byte(jsonconfig), s) } diff --git a/pkg/logs/log.go b/pkg/logs/log.go index d47173e5..fd8fca63 100644 --- a/pkg/logs/log.go +++ b/pkg/logs/log.go @@ -84,7 +84,7 @@ type newLoggerFunc func() Logger // Logger defines the behavior of a log provider. type Logger interface { - Init(config string) error + Init(config string, LogFormatter ...func(*LogMsg) string) error WriteMsg(lm *LogMsg) error Format(lm *LogMsg) string Destroy() @@ -115,6 +115,7 @@ type BeeLogger struct { init bool enableFuncCallDepth bool loggerFuncCallDepth int + globalFormatter func(*LogMsg) string enableFullFilePath bool asynchronous bool prefix string @@ -129,8 +130,6 @@ const defaultAsyncMsgLen = 1e3 type nameLogger struct { Logger - // Formatter func(*LogMsg) string - LogFormatter name string } @@ -206,7 +205,16 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { } lg := logAdapter() - err := lg.Init(config) + var err error + + // Global formatter overrides the default set formatter + // but not adapter specific formatters set with logs.SetLoggerWithOpts() + if bl.globalFormatter != nil { + err = lg.Init(config, bl.globalFormatter) + } else { + err = lg.Init(config) + } + if err != nil { fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) return err @@ -248,7 +256,6 @@ func (bl *BeeLogger) DelLogger(adapterName string) error { func (bl *BeeLogger) writeToLoggers(lm *LogMsg) { for _, l := range bl.outputs { - // fmt.Println("Formatted: ", l.Format(lm)) err := l.WriteMsg(lm) if err != nil { fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err) @@ -394,6 +401,74 @@ func (bl *BeeLogger) startLogger() { } } +// SetLoggerWithOpts sets a log adapter with a user defined logging format. Config must be valid JSON +// such as: {"interval":360} +func (bl *BeeLogger) setLoggerWithOpts(adapterName string, formatterFunc func(*LogMsg) string, configs ...string) error { + config := append(configs, "{}")[0] + for _, l := range bl.outputs { + if l.name == adapterName { + return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName) + } + } + + logAdapter, ok := adapters[adapterName] + if !ok { + return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) + } + + if formatterFunc == nil { + return fmt.Errorf("No formatter set for %s log adapter", adapterName) + } + + lg := logAdapter() + err := lg.Init(config, formatterFunc) + if err != nil { + fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) + return err + } + + bl.outputs = append(bl.outputs, &nameLogger{ + name: adapterName, + Logger: lg, + }) + + return nil +} + +// SetLogger provides a given logger adapter into BeeLogger with config string. +func (bl *BeeLogger) SetLoggerWithOpts(adapterName string, formatterFunc func(*LogMsg) string, configs ...string) error { + bl.lock.Lock() + defer bl.lock.Unlock() + if !bl.init { + bl.outputs = []*nameLogger{} + bl.init = true + } + return bl.setLoggerWithOpts(adapterName, formatterFunc, configs...) +} + +// SetLoggerWIthOpts sets a given log adapter with a custom log adapter. +// Log Adapter must be given in the form common.SimpleKV{Key: "formatter": Value: struct.FormatFunc} +// where FormatFunc has the signature func(*LogMsg) string +func SetLoggerWithOpts(adapter string, config []string, formatterFunc func(*LogMsg) string) error { + err := beeLogger.SetLoggerWithOpts(adapter, formatterFunc, config...) + if err != nil { + log.Fatal(err) + } + return nil + +} + +func (bl *BeeLogger) setGlobalFormatter(fmtter func(*LogMsg) string) error { + bl.globalFormatter = fmtter + return nil +} + +// SetGlobalFormatter sets the global formatter for all log adapters +// This overrides and other individually set adapter +func SetGlobalFormatter(fmtter func(*LogMsg) string) error { + return beeLogger.setGlobalFormatter(fmtter) +} + // Emergency Log EMERGENCY level message. func (bl *BeeLogger) Emergency(format string, v ...interface{}) { if LevelEmergency > bl.level { diff --git a/pkg/logs/multifile.go b/pkg/logs/multifile.go index 0650c99d..bcd4dd4e 100644 --- a/pkg/logs/multifile.go +++ b/pkg/logs/multifile.go @@ -24,9 +24,11 @@ import ( // and write the error-level logs to project.error.log and write the debug-level logs to project.debug.log // the rotate attribute also acts like fileLogWriter type multiFileLogWriter struct { - writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter - fullLogWriter *fileLogWriter - Separate []string `json:"separate"` + writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter + fullLogWriter *fileLogWriter + Separate []string `json:"separate"` + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string } var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"} @@ -44,7 +46,14 @@ var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning // "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"], // } -func (f *multiFileLogWriter) Init(config string) error { +func (f *multiFileLogWriter) Init(config string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + f.UseCustomFormatter = true + f.CustomFormatter = elem + } + } + writer := newFileWriter().(*fileLogWriter) err := writer.Init(config) if err != nil { @@ -74,7 +83,6 @@ func (f *multiFileLogWriter) Init(config string) error { } } } - return nil } diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go index c0584f72..9407b48a 100644 --- a/pkg/logs/slack.go +++ b/pkg/logs/slack.go @@ -9,8 +9,10 @@ import ( // SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook type SLACKWriter struct { - WebhookURL string `json:"webhookurl"` - Level int `json:"level"` + WebhookURL string `json:"webhookurl"` + Level int `json:"level"` + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string } // newSLACKWriter creates jiaoliao writer. @@ -23,7 +25,14 @@ func (s *SLACKWriter) Format(lm *LogMsg) string { } // Init SLACKWriter with json config string -func (s *SLACKWriter) Init(jsonconfig string) error { +func (s *SLACKWriter) Init(jsonconfig string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + s.UseCustomFormatter = true + s.CustomFormatter = elem + } + } + return json.Unmarshal([]byte(jsonconfig), s) } diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go index d992b279..b81be68f 100644 --- a/pkg/logs/smtp.go +++ b/pkg/logs/smtp.go @@ -32,6 +32,8 @@ type SMTPWriter struct { FromAddress string `json:"fromAddress"` RecipientAddresses []string `json:"sendTos"` Level int `json:"level"` + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string } // NewSMTPWriter creates the smtp writer. @@ -50,7 +52,14 @@ func newSMTPWriter() Logger { // "sendTos":["email1","email2"], // "level":LevelError // } -func (s *SMTPWriter) Init(jsonconfig string) error { +func (s *SMTPWriter) Init(jsonconfig string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + s.UseCustomFormatter = true + s.CustomFormatter = elem + } + } + return json.Unmarshal([]byte(jsonconfig), s) } From 48a98ec1a5c7aeb7674b2b09885f3dd9d9e575d4 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Mon, 24 Aug 2020 20:39:53 +0100 Subject: [PATCH 06/35] Fix init for alils.go --- pkg/logs/alils/alils.go | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go index 2c83e4ee..2300f8f8 100644 --- a/pkg/logs/alils/alils.go +++ b/pkg/logs/alils/alils.go @@ -32,11 +32,13 @@ type Config struct { // aliLSWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type aliLSWriter struct { - store *LogStore - group []*LogGroup - withMap bool - groupMap map[string]*LogGroup - lock *sync.Mutex + store *LogStore + group []*LogGroup + withMap bool + groupMap map[string]*LogGroup + lock *sync.Mutex + UseCustomFormatter bool + CustomFormatter func(*logs.LogMsg) string Config } @@ -48,7 +50,14 @@ func NewAliLS() logs.Logger { } // Init parses config and initializes struct -func (c *aliLSWriter) Init(jsonConfig string) (err error) { +func (c *aliLSWriter) Init(jsonConfig string, LogFormatter ...func(*logs.LogMsg) string) (err error) { + + for _, elem := range LogFormatter { + if elem != nil { + c.UseCustomFormatter = true + c.CustomFormatter = elem + } + } json.Unmarshal([]byte(jsonConfig), c) @@ -135,6 +144,12 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { lg = c.group[0] } + if c.UseCustomFormatter { + content = c.CustomFormatter(lm) + } else { + content = c.Format(lm) + } + c1 := &LogContent{ Key: proto.String("msg"), Value: proto.String(content), From c5970766a35cbc588c3b59b21d626e296894ffe1 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Mon, 24 Aug 2020 20:41:39 +0100 Subject: [PATCH 07/35] Add init to es.go --- pkg/logs/es/es.go | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go index 06dfece1..4dfc4160 100644 --- a/pkg/logs/es/es.go +++ b/pkg/logs/es/es.go @@ -31,8 +31,10 @@ func NewES() logs.Logger { // import _ "github.com/astaxie/beego/logs/es" type esLogger struct { *elasticsearch.Client - DSN string `json:"dsn"` - Level int `json:"level"` + DSN string `json:"dsn"` + Level int `json:"level"` + UseCustomFormatter bool + CustomFormatter func(*logs.LogMsg) string } func (el *esLogger) Format(lm *logs.LogMsg) string { @@ -40,7 +42,14 @@ func (el *esLogger) Format(lm *logs.LogMsg) string { } // {"dsn":"http://localhost:9200/","level":1} -func (el *esLogger) Init(jsonconfig string) error { +func (el *esLogger) Init(jsonconfig string, LogFormatter ...func(*logs.LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + el.UseCustomFormatter = true + el.CustomFormatter = elem + } + } + err := json.Unmarshal([]byte(jsonconfig), el) if err != nil { return err @@ -69,9 +78,16 @@ func (el *esLogger) WriteMsg(lm *logs.LogMsg) error { return nil } + msg := "" + if el.UseCustomFormatter { + msg = el.CustomFormatter(lm) + } else { + msg = el.Format(lm) + } + idx := LogDocument{ Timestamp: lm.When.Format(time.RFC3339), - Msg: el.Format(lm), + Msg: msg, } body, err := json.Marshal(idx) From c2471b22ad04bf1623aab8fc3dc8f2d5f6461a88 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Mon, 24 Aug 2020 20:54:55 +0100 Subject: [PATCH 08/35] Remove ineffectual assignments Removed 3 lines due to warning from test suite saying these lines had innefectual assignments --- pkg/logs/alils/alils.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go index 2300f8f8..183d9b24 100644 --- a/pkg/logs/alils/alils.go +++ b/pkg/logs/alils/alils.go @@ -130,17 +130,14 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { if len(strs) == 2 { pos := strings.LastIndex(strs[0], " ") topic = strs[0][pos+1 : len(strs[0])] - content = strs[0][0:pos] + strs[1] lg = c.groupMap[topic] } // send to empty Topic if lg == nil { - content = lm.Msg lg = c.group[0] } } else { - content = lm.Msg lg = c.group[0] } From d24f861629303f96f1400223e0e31f4de63b7686 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Mon, 24 Aug 2020 21:00:58 +0100 Subject: [PATCH 09/35] empty commit to restart CI From 2b39ff78374f3b99c4c874b7cf71b1f50e058e7e Mon Sep 17 00:00:00 2001 From: IamCathal Date: Fri, 28 Aug 2020 18:00:45 +0100 Subject: [PATCH 10/35] New opts formatter working for console --- pkg/logs/conn.go | 32 ++++++++++++++++---------------- pkg/logs/console.go | 31 +++++++++++++++++++------------ pkg/logs/file.go | 16 +++++++++------- pkg/logs/jianliao.go | 18 ++++++++++-------- pkg/logs/log.go | 36 ++++++++++++++++++++++++++---------- pkg/logs/multifile.go | 22 ++++++++++++---------- pkg/logs/slack.go | 17 +++++++++-------- pkg/logs/smtp.go | 18 ++++++++++-------- 8 files changed, 111 insertions(+), 79 deletions(-) diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go index e11909a0..55cbecdd 100644 --- a/pkg/logs/conn.go +++ b/pkg/logs/conn.go @@ -18,20 +18,20 @@ import ( "encoding/json" "io" "net" + + "github.com/astaxie/beego/pkg/common" ) // connWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type connWriter struct { - lg *logWriter - innerWriter io.WriteCloser - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string - ReconnectOnMsg bool `json:"reconnectOnMsg"` - Reconnect bool `json:"reconnect"` - Net string `json:"net"` - Addr string `json:"addr"` - Level int `json:"level"` + lg *logWriter + innerWriter io.WriteCloser + ReconnectOnMsg bool `json:"reconnectOnMsg"` + Reconnect bool `json:"reconnect"` + Net string `json:"net"` + Addr string `json:"addr"` + Level int `json:"level"` } // NewConn creates new ConnWrite returning as LoggerInterface. @@ -47,13 +47,13 @@ func (c *connWriter) Format(lm *LogMsg) string { // Init initializes a connection writer with json config. // json config only needs they "level" key -func (c *connWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - c.UseCustomFormatter = true - c.CustomFormatter = elem - } - } +func (c *connWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + // for _, elem := range LogFormatter { + // if elem != nil { + // c.UseCustomFormatter = true + // c.CustomFormatter = elem + // } + // } return json.Unmarshal([]byte(jsonConfig), c) } diff --git a/pkg/logs/console.go b/pkg/logs/console.go index a928de7d..55958008 100644 --- a/pkg/logs/console.go +++ b/pkg/logs/console.go @@ -19,6 +19,8 @@ import ( "os" "strings" + "github.com/astaxie/beego/pkg/common" + "github.com/shiena/ansicolor" ) @@ -47,11 +49,10 @@ var colors = []brush{ // consoleWriter implements LoggerInterface and writes messages to terminal. type consoleWriter struct { - lg *logWriter - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string - Level int `json:"level"` - Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color + lg *logWriter + customFormatter func(*LogMsg) string + Level int `json:"level"` + Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color } func (c *consoleWriter) Format(lm *LogMsg) string { @@ -80,11 +81,16 @@ func NewConsole() Logger { // Init initianlizes the console logger. // jsonConfig must be in the format '{"level":LevelTrace}' -func (c *consoleWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - c.UseCustomFormatter = true - c.CustomFormatter = elem +// func (c *consoleWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { +func (c *consoleWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := GetFormatter(elem) + if err != nil { + return err + } + c.customFormatter = formatter } } @@ -107,10 +113,11 @@ func (c *consoleWriter) WriteMsg(lm *LogMsg) error { msg := "" - if c.UseCustomFormatter { - msg = c.CustomFormatter(lm) + if c.customFormatter != nil { + msg = c.customFormatter(lm) } else { msg = c.Format(lm) + } c.lg.writeln(msg) diff --git a/pkg/logs/file.go b/pkg/logs/file.go index 4576e19d..0324486e 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -27,6 +27,8 @@ import ( "strings" "sync" "time" + + "github.com/astaxie/beego/pkg/common" ) // fileLogWriter implements LoggerInterface. @@ -107,13 +109,13 @@ func (w *fileLogWriter) Format(lm *LogMsg) string { // "rotate":true, // "perm":"0600" // } -func (w *fileLogWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - w.UseCustomFormatter = true - w.CustomFormatter = elem - } - } +func (w *fileLogWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + // for _, elem := range LogFormatter { + // if elem != nil { + // w.UseCustomFormatter = true + // w.CustomFormatter = elem + // } + // } err := json.Unmarshal([]byte(jsonConfig), w) if err != nil { diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go index 9877bed6..8daa8015 100644 --- a/pkg/logs/jianliao.go +++ b/pkg/logs/jianliao.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "net/url" + + "github.com/astaxie/beego/pkg/common" ) // JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook @@ -25,15 +27,15 @@ func newJLWriter() Logger { } // Init JLWriter with json config string -func (s *JLWriter) Init(jsonconfig string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - s.UseCustomFormatter = true - s.CustomFormatter = elem - } - } +func (s *JLWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + // for _, elem := range LogFormatter { + // if elem != nil { + // s.UseCustomFormatter = true + // s.CustomFormatter = elem + // } + // } - return json.Unmarshal([]byte(jsonconfig), s) + return json.Unmarshal([]byte(jsonConfig), s) } func (s *JLWriter) Format(lm *LogMsg) string { diff --git a/pkg/logs/log.go b/pkg/logs/log.go index fd8fca63..9529c865 100644 --- a/pkg/logs/log.go +++ b/pkg/logs/log.go @@ -38,10 +38,13 @@ import ( "log" "os" "path" + "reflect" "runtime" "strings" "sync" "time" + + "github.com/astaxie/beego/pkg/common" ) // RFC5424 log message levels. @@ -84,7 +87,7 @@ type newLoggerFunc func() Logger // Logger defines the behavior of a log provider. type Logger interface { - Init(config string, LogFormatter ...func(*LogMsg) string) error + Init(config string, opts ...common.SimpleKV) error WriteMsg(lm *LogMsg) error Format(lm *LogMsg) string Destroy() @@ -210,7 +213,7 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { // Global formatter overrides the default set formatter // but not adapter specific formatters set with logs.SetLoggerWithOpts() if bl.globalFormatter != nil { - err = lg.Init(config, bl.globalFormatter) + err = lg.Init(config) } else { err = lg.Init(config) } @@ -401,9 +404,21 @@ func (bl *BeeLogger) startLogger() { } } +// Get the formatter from the opts common.SimpleKV structure +// Looks for a key: "formatter" with value: func(*LogMsg) string +func GetFormatter(opts common.SimpleKV) (func(*LogMsg) string, error) { + if strings.ToLower(opts.Key.(string)) == "formatter" { + formatterInterface := reflect.ValueOf(opts.Value).Interface() + formatterFunc := formatterInterface.(func(*LogMsg) string) + return formatterFunc, nil + } + + return nil, fmt.Errorf("no \"formatter\" key given in simpleKV") +} + // SetLoggerWithOpts sets a log adapter with a user defined logging format. Config must be valid JSON // such as: {"interval":360} -func (bl *BeeLogger) setLoggerWithOpts(adapterName string, formatterFunc func(*LogMsg) string, configs ...string) error { +func (bl *BeeLogger) setLoggerWithOpts(adapterName string, opts common.SimpleKV, configs ...string) error { config := append(configs, "{}")[0] for _, l := range bl.outputs { if l.name == adapterName { @@ -416,12 +431,12 @@ func (bl *BeeLogger) setLoggerWithOpts(adapterName string, formatterFunc func(*L return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) } - if formatterFunc == nil { - return fmt.Errorf("No formatter set for %s log adapter", adapterName) + if opts.Key == nil { + return fmt.Errorf("No SimpleKV struct set for %s log adapter", adapterName) } lg := logAdapter() - err := lg.Init(config, formatterFunc) + err := lg.Init(config, opts) if err != nil { fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) return err @@ -436,21 +451,22 @@ func (bl *BeeLogger) setLoggerWithOpts(adapterName string, formatterFunc func(*L } // SetLogger provides a given logger adapter into BeeLogger with config string. -func (bl *BeeLogger) SetLoggerWithOpts(adapterName string, formatterFunc func(*LogMsg) string, configs ...string) error { +func (bl *BeeLogger) SetLoggerWithOpts(adapterName string, opts common.SimpleKV, configs ...string) error { bl.lock.Lock() defer bl.lock.Unlock() if !bl.init { bl.outputs = []*nameLogger{} bl.init = true } - return bl.setLoggerWithOpts(adapterName, formatterFunc, configs...) + return bl.setLoggerWithOpts(adapterName, opts, configs...) } // SetLoggerWIthOpts sets a given log adapter with a custom log adapter. // Log Adapter must be given in the form common.SimpleKV{Key: "formatter": Value: struct.FormatFunc} // where FormatFunc has the signature func(*LogMsg) string -func SetLoggerWithOpts(adapter string, config []string, formatterFunc func(*LogMsg) string) error { - err := beeLogger.SetLoggerWithOpts(adapter, formatterFunc, config...) +// func SetLoggerWithOpts(adapter string, config []string, formatterFunc func(*LogMsg) string) error { +func SetLoggerWithOpts(adapter string, config []string, opts common.SimpleKV) error { + err := beeLogger.SetLoggerWithOpts(adapter, opts, config...) if err != nil { log.Fatal(err) } diff --git a/pkg/logs/multifile.go b/pkg/logs/multifile.go index bcd4dd4e..720f5125 100644 --- a/pkg/logs/multifile.go +++ b/pkg/logs/multifile.go @@ -16,6 +16,8 @@ package logs import ( "encoding/json" + + "github.com/astaxie/beego/pkg/common" ) // A filesLogWriter manages several fileLogWriter @@ -46,16 +48,16 @@ var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning // "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"], // } -func (f *multiFileLogWriter) Init(config string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - f.UseCustomFormatter = true - f.CustomFormatter = elem - } - } +func (f *multiFileLogWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + // for _, elem := range LogFormatter { + // if elem != nil { + // f.UseCustomFormatter = true + // f.CustomFormatter = elem + // } + // } writer := newFileWriter().(*fileLogWriter) - err := writer.Init(config) + err := writer.Init(jsonConfig) if err != nil { return err } @@ -63,10 +65,10 @@ func (f *multiFileLogWriter) Init(config string, LogFormatter ...func(*LogMsg) s f.writers[LevelDebug+1] = writer //unmarshal "separate" field to f.Separate - json.Unmarshal([]byte(config), f) + json.Unmarshal([]byte(jsonConfig), f) jsonMap := map[string]interface{}{} - json.Unmarshal([]byte(config), &jsonMap) + json.Unmarshal([]byte(jsonConfig), &jsonMap) for i := LevelEmergency; i < LevelDebug+1; i++ { for _, v := range f.Separate { diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go index 9407b48a..0fc75149 100644 --- a/pkg/logs/slack.go +++ b/pkg/logs/slack.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "net/url" + + "github.com/astaxie/beego/pkg/common" ) // SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook @@ -25,15 +27,14 @@ func (s *SLACKWriter) Format(lm *LogMsg) string { } // Init SLACKWriter with json config string -func (s *SLACKWriter) Init(jsonconfig string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - s.UseCustomFormatter = true - s.CustomFormatter = elem - } - } +func (s *SLACKWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + // if elem != nil { + // s.UseCustomFormatter = true + // s.CustomFormatter = elem + // } + // } - return json.Unmarshal([]byte(jsonconfig), s) + return json.Unmarshal([]byte(jsonConfig), s) } // WriteMsg write message in smtp writer. diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go index b81be68f..17148812 100644 --- a/pkg/logs/smtp.go +++ b/pkg/logs/smtp.go @@ -21,6 +21,8 @@ import ( "net" "net/smtp" "strings" + + "github.com/astaxie/beego/pkg/common" ) // SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. @@ -52,15 +54,15 @@ func newSMTPWriter() Logger { // "sendTos":["email1","email2"], // "level":LevelError // } -func (s *SMTPWriter) Init(jsonconfig string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - s.UseCustomFormatter = true - s.CustomFormatter = elem - } - } +func (s *SMTPWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + // for _, elem := range LogFormatter { + // if elem != nil { + // s.UseCustomFormatter = true + // s.CustomFormatter = elem + // } + // } - return json.Unmarshal([]byte(jsonconfig), s) + return json.Unmarshal([]byte(jsonConfig), s) } func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { From 8178f035a08231ec7a04ab7a825ef1cacffac4d4 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Fri, 28 Aug 2020 18:18:28 +0100 Subject: [PATCH 11/35] Custom formatting opts implementation --- pkg/logs/alils/alils.go | 36 ++++++++++++++++++++---------------- pkg/logs/conn.go | 40 ++++++++++++++++++++++++++-------------- pkg/logs/console.go | 1 - pkg/logs/es/es.go | 28 ++++++++++++++++------------ pkg/logs/file.go | 24 ++++++++++++++---------- pkg/logs/jianliao.go | 40 +++++++++++++++++++++++++--------------- pkg/logs/multifile.go | 24 +++++++++++++----------- pkg/logs/smtp.go | 19 +++++++++++-------- 8 files changed, 125 insertions(+), 87 deletions(-) diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go index 183d9b24..425071f8 100644 --- a/pkg/logs/alils/alils.go +++ b/pkg/logs/alils/alils.go @@ -5,6 +5,7 @@ import ( "strings" "sync" + "github.com/astaxie/beego/pkg/common" "github.com/astaxie/beego/pkg/logs" "github.com/gogo/protobuf/proto" ) @@ -32,13 +33,12 @@ type Config struct { // aliLSWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type aliLSWriter struct { - store *LogStore - group []*LogGroup - withMap bool - groupMap map[string]*LogGroup - lock *sync.Mutex - UseCustomFormatter bool - CustomFormatter func(*logs.LogMsg) string + store *LogStore + group []*LogGroup + withMap bool + groupMap map[string]*LogGroup + lock *sync.Mutex + customFormatter func(*logs.LogMsg) string Config } @@ -50,15 +50,17 @@ func NewAliLS() logs.Logger { } // Init parses config and initializes struct -func (c *aliLSWriter) Init(jsonConfig string, LogFormatter ...func(*logs.LogMsg) string) (err error) { +func (c *aliLSWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { - for _, elem := range LogFormatter { - if elem != nil { - c.UseCustomFormatter = true - c.CustomFormatter = elem + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := logs.GetFormatter(elem) + if err != nil { + return err + } + c.customFormatter = formatter } } - json.Unmarshal([]byte(jsonConfig), c) if c.FlushWhen > CacheSize { @@ -72,11 +74,13 @@ func (c *aliLSWriter) Init(jsonConfig string, LogFormatter ...func(*logs.LogMsg) AccessKeySecret: c.KeySecret, } - c.store, err = prj.GetLogStore(c.LogStore) + store, err := prj.GetLogStore(c.LogStore) if err != nil { return err } + c.store = store + // Create default Log Group c.group = append(c.group, &LogGroup{ Topic: proto.String(""), @@ -141,8 +145,8 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { lg = c.group[0] } - if c.UseCustomFormatter { - content = c.CustomFormatter(lm) + if c.customFormatter != nil { + content = c.customFormatter(lm) } else { content = c.Format(lm) } diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go index 55cbecdd..9a520bda 100644 --- a/pkg/logs/conn.go +++ b/pkg/logs/conn.go @@ -25,13 +25,14 @@ import ( // connWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type connWriter struct { - lg *logWriter - innerWriter io.WriteCloser - ReconnectOnMsg bool `json:"reconnectOnMsg"` - Reconnect bool `json:"reconnect"` - Net string `json:"net"` - Addr string `json:"addr"` - Level int `json:"level"` + lg *logWriter + innerWriter io.WriteCloser + customFormatter func(*LogMsg) string + ReconnectOnMsg bool `json:"reconnectOnMsg"` + Reconnect bool `json:"reconnect"` + Net string `json:"net"` + Addr string `json:"addr"` + Level int `json:"level"` } // NewConn creates new ConnWrite returning as LoggerInterface. @@ -48,12 +49,16 @@ func (c *connWriter) Format(lm *LogMsg) string { // Init initializes a connection writer with json config. // json config only needs they "level" key func (c *connWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { - // for _, elem := range LogFormatter { - // if elem != nil { - // c.UseCustomFormatter = true - // c.CustomFormatter = elem - // } - // } + + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := GetFormatter(elem) + if err != nil { + return err + } + c.customFormatter = formatter + } + } return json.Unmarshal([]byte(jsonConfig), c) } @@ -75,7 +80,14 @@ func (c *connWriter) WriteMsg(lm *LogMsg) error { defer c.innerWriter.Close() } - msg := c.Format(lm) + msg := "" + if c.customFormatter != nil { + msg = c.customFormatter(lm) + } else { + msg = c.Format(lm) + + } + _, err := c.lg.writeln(msg) if err != nil { return err diff --git a/pkg/logs/console.go b/pkg/logs/console.go index 55958008..34114e4a 100644 --- a/pkg/logs/console.go +++ b/pkg/logs/console.go @@ -81,7 +81,6 @@ func NewConsole() Logger { // Init initianlizes the console logger. // jsonConfig must be in the format '{"level":LevelTrace}' -// func (c *consoleWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { func (c *consoleWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { for _, elem := range opts { diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go index 4dfc4160..dc9304c8 100644 --- a/pkg/logs/es/es.go +++ b/pkg/logs/es/es.go @@ -12,6 +12,7 @@ import ( "github.com/elastic/go-elasticsearch/v6" "github.com/elastic/go-elasticsearch/v6/esapi" + "github.com/astaxie/beego/pkg/common" "github.com/astaxie/beego/pkg/logs" ) @@ -31,10 +32,9 @@ func NewES() logs.Logger { // import _ "github.com/astaxie/beego/logs/es" type esLogger struct { *elasticsearch.Client - DSN string `json:"dsn"` - Level int `json:"level"` - UseCustomFormatter bool - CustomFormatter func(*logs.LogMsg) string + DSN string `json:"dsn"` + Level int `json:"level"` + customFormatter func(*logs.LogMsg) string } func (el *esLogger) Format(lm *logs.LogMsg) string { @@ -42,15 +42,19 @@ func (el *esLogger) Format(lm *logs.LogMsg) string { } // {"dsn":"http://localhost:9200/","level":1} -func (el *esLogger) Init(jsonconfig string, LogFormatter ...func(*logs.LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - el.UseCustomFormatter = true - el.CustomFormatter = elem +func (el *esLogger) Init(jsonConfig string, opts ...common.SimpleKV) error { + + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := logs.GetFormatter(elem) + if err != nil { + return err + } + el.customFormatter = formatter } } - err := json.Unmarshal([]byte(jsonconfig), el) + err := json.Unmarshal([]byte(jsonConfig), el) if err != nil { return err } @@ -79,8 +83,8 @@ func (el *esLogger) WriteMsg(lm *logs.LogMsg) error { } msg := "" - if el.UseCustomFormatter { - msg = el.CustomFormatter(lm) + if el.customFormatter != nil { + msg = el.customFormatter(lm) } else { msg = el.Format(lm) } diff --git a/pkg/logs/file.go b/pkg/logs/file.go index 0324486e..42148c3a 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -62,8 +62,7 @@ type fileLogWriter struct { hourlyOpenDate int hourlyOpenTime time.Time - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string + customFormatter func(*LogMsg) string Rotate bool `json:"rotate"` @@ -110,12 +109,16 @@ func (w *fileLogWriter) Format(lm *LogMsg) string { // "perm":"0600" // } func (w *fileLogWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { - // for _, elem := range LogFormatter { - // if elem != nil { - // w.UseCustomFormatter = true - // w.CustomFormatter = elem - // } - // } + + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := GetFormatter(elem) + if err != nil { + return err + } + w.customFormatter = formatter + } + } err := json.Unmarshal([]byte(jsonConfig), w) if err != nil { @@ -166,8 +169,9 @@ func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { } hd, d, h := formatTimeHeader(lm.When) msg := "" - if w.UseCustomFormatter { - msg = w.CustomFormatter(lm) + + if w.customFormatter != nil { + msg = w.customFormatter(lm) } else { msg = w.Format(lm) } diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go index 8daa8015..81d0195b 100644 --- a/pkg/logs/jianliao.go +++ b/pkg/logs/jianliao.go @@ -11,14 +11,13 @@ import ( // JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook type JLWriter struct { - AuthorName string `json:"authorname"` - Title string `json:"title"` - WebhookURL string `json:"webhookurl"` - RedirectURL string `json:"redirecturl,omitempty"` - ImageURL string `json:"imageurl,omitempty"` - Level int `json:"level"` - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string + AuthorName string `json:"authorname"` + Title string `json:"title"` + WebhookURL string `json:"webhookurl"` + RedirectURL string `json:"redirecturl,omitempty"` + ImageURL string `json:"imageurl,omitempty"` + Level int `json:"level"` + customFormatter func(*LogMsg) string } // newJLWriter creates jiaoliao writer. @@ -28,12 +27,15 @@ func newJLWriter() Logger { // Init JLWriter with json config string func (s *JLWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { - // for _, elem := range LogFormatter { - // if elem != nil { - // s.UseCustomFormatter = true - // s.CustomFormatter = elem - // } - // } + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := GetFormatter(elem) + if err != nil { + return err + } + s.customFormatter = formatter + } + } return json.Unmarshal([]byte(jsonConfig), s) } @@ -49,7 +51,15 @@ func (s *JLWriter) WriteMsg(lm *LogMsg) error { return nil } - text := fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.Format(lm)) + text := "" + + if s.customFormatter != nil { + text = fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.customFormatter(lm)) + } else { + text = fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.Format(lm)) + + } + form := url.Values{} form.Add("authorName", s.AuthorName) form.Add("title", s.Title) diff --git a/pkg/logs/multifile.go b/pkg/logs/multifile.go index 720f5125..c1b7cfdd 100644 --- a/pkg/logs/multifile.go +++ b/pkg/logs/multifile.go @@ -26,11 +26,10 @@ import ( // and write the error-level logs to project.error.log and write the debug-level logs to project.debug.log // the rotate attribute also acts like fileLogWriter type multiFileLogWriter struct { - writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter - fullLogWriter *fileLogWriter - Separate []string `json:"separate"` - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string + writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter + fullLogWriter *fileLogWriter + Separate []string `json:"separate"` + customFormatter func(*LogMsg) string } var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"} @@ -49,12 +48,15 @@ var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning // } func (f *multiFileLogWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { - // for _, elem := range LogFormatter { - // if elem != nil { - // f.UseCustomFormatter = true - // f.CustomFormatter = elem - // } - // } + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := GetFormatter(elem) + if err != nil { + return err + } + f.customFormatter = formatter + } + } writer := newFileWriter().(*fileLogWriter) err := writer.Init(jsonConfig) diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go index 17148812..9b67e343 100644 --- a/pkg/logs/smtp.go +++ b/pkg/logs/smtp.go @@ -34,8 +34,7 @@ type SMTPWriter struct { FromAddress string `json:"fromAddress"` RecipientAddresses []string `json:"sendTos"` Level int `json:"level"` - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string + customFormatter func(*LogMsg) string } // NewSMTPWriter creates the smtp writer. @@ -55,12 +54,16 @@ func newSMTPWriter() Logger { // "level":LevelError // } func (s *SMTPWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { - // for _, elem := range LogFormatter { - // if elem != nil { - // s.UseCustomFormatter = true - // s.CustomFormatter = elem - // } - // } + + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := GetFormatter(elem) + if err != nil { + return err + } + s.customFormatter = formatter + } + } return json.Unmarshal([]byte(jsonConfig), s) } From e0a934af1d8bb4f946e881bd330ca578b1dc41d7 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Fri, 28 Aug 2020 18:24:57 +0100 Subject: [PATCH 12/35] empty commit to restart CI From 6684924e995a5a15a513e667ded09e4796bf6aa6 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Fri, 28 Aug 2020 18:30:41 +0100 Subject: [PATCH 13/35] empty commit to restart CI again From 0189e6329a4e1700ea589de4eebe29e8624b422d Mon Sep 17 00:00:00 2001 From: IamCathal Date: Fri, 28 Aug 2020 18:47:28 +0100 Subject: [PATCH 14/35] Add global logging override --- pkg/logs/log.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/logs/log.go b/pkg/logs/log.go index 9529c865..e18ea95b 100644 --- a/pkg/logs/log.go +++ b/pkg/logs/log.go @@ -213,7 +213,7 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { // Global formatter overrides the default set formatter // but not adapter specific formatters set with logs.SetLoggerWithOpts() if bl.globalFormatter != nil { - err = lg.Init(config) + err = lg.Init(config, common.SimpleKV{Key: "formatter", Value: bl.globalFormatter}) } else { err = lg.Init(config) } From 185d55eb4638c6432a73d7ae94b3e3654ce134e6 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 1 Sep 2020 21:25:29 +0800 Subject: [PATCH 15/35] adapt config --- pkg/adapter/config/adapter.go | 193 +++++++++++++++++++++++ pkg/adapter/config/config.go | 151 ++++++++++++++++++ pkg/adapter/config/config_test.go | 55 +++++++ pkg/adapter/config/env/env.go | 50 ++++++ pkg/adapter/config/env/env_test.go | 75 +++++++++ pkg/adapter/config/fake.go | 25 +++ pkg/adapter/config/ini_test.go | 190 +++++++++++++++++++++++ pkg/adapter/config/json.go | 19 +++ pkg/adapter/config/json_test.go | 222 +++++++++++++++++++++++++++ pkg/adapter/config/xml/xml.go | 34 ++++ pkg/adapter/config/xml/xml_test.go | 125 +++++++++++++++ pkg/adapter/config/yaml/yaml.go | 34 ++++ pkg/adapter/config/yaml/yaml_test.go | 115 ++++++++++++++ 13 files changed, 1288 insertions(+) create mode 100644 pkg/adapter/config/adapter.go create mode 100644 pkg/adapter/config/config.go create mode 100644 pkg/adapter/config/config_test.go create mode 100644 pkg/adapter/config/env/env.go create mode 100644 pkg/adapter/config/env/env_test.go create mode 100644 pkg/adapter/config/fake.go create mode 100644 pkg/adapter/config/ini_test.go create mode 100644 pkg/adapter/config/json.go create mode 100644 pkg/adapter/config/json_test.go create mode 100644 pkg/adapter/config/xml/xml.go create mode 100644 pkg/adapter/config/xml/xml_test.go create mode 100644 pkg/adapter/config/yaml/yaml.go create mode 100644 pkg/adapter/config/yaml/yaml_test.go diff --git a/pkg/adapter/config/adapter.go b/pkg/adapter/config/adapter.go new file mode 100644 index 00000000..f74b3ff9 --- /dev/null +++ b/pkg/adapter/config/adapter.go @@ -0,0 +1,193 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "context" + + "github.com/pkg/errors" + + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +type newToOldConfigerAdapter struct { + delegate config.Configer +} + +func (c *newToOldConfigerAdapter) Set(key, val string) error { + return c.delegate.Set(context.Background(), key, val) +} + +func (c *newToOldConfigerAdapter) String(key string) string { + res, _ := c.delegate.String(context.Background(), key) + return res +} + +func (c *newToOldConfigerAdapter) Strings(key string) []string { + res, _ := c.delegate.Strings(context.Background(), key) + return res +} + +func (c *newToOldConfigerAdapter) Int(key string) (int, error) { + return c.delegate.Int(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) Int64(key string) (int64, error) { + return c.delegate.Int64(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) Bool(key string) (bool, error) { + return c.delegate.Bool(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) Float(key string) (float64, error) { + return c.delegate.Float(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) DefaultString(key string, defaultVal string) string { + return c.delegate.DefaultString(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultStrings(key string, defaultVal []string) []string { + return c.delegate.DefaultStrings(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultInt(key string, defaultVal int) int { + return c.delegate.DefaultInt(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultInt64(key string, defaultVal int64) int64 { + return c.delegate.DefaultInt64(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultBool(key string, defaultVal bool) bool { + return c.delegate.DefaultBool(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultFloat(key string, defaultVal float64) float64 { + return c.delegate.DefaultFloat(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DIY(key string) (interface{}, error) { + return c.delegate.DIY(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) GetSection(section string) (map[string]string, error) { + return c.delegate.GetSection(context.Background(), section) +} + +func (c *newToOldConfigerAdapter) SaveConfigFile(filename string) error { + return c.delegate.SaveConfigFile(context.Background(), filename) +} + +type oldToNewConfigerAdapter struct { + delegate Configer +} + +func (o *oldToNewConfigerAdapter) Set(ctx context.Context, key, val string) error { + return o.delegate.Set(key, val) +} + +func (o *oldToNewConfigerAdapter) String(ctx context.Context, key string) (string, error) { + return o.delegate.String(key), nil +} + +func (o *oldToNewConfigerAdapter) Strings(ctx context.Context, key string) ([]string, error) { + return o.delegate.Strings(key), nil +} + +func (o *oldToNewConfigerAdapter) Int(ctx context.Context, key string) (int, error) { + return o.delegate.Int(key) +} + +func (o *oldToNewConfigerAdapter) Int64(ctx context.Context, key string) (int64, error) { + return o.delegate.Int64(key) +} + +func (o *oldToNewConfigerAdapter) Bool(ctx context.Context, key string) (bool, error) { + return o.delegate.Bool(key) +} + +func (o *oldToNewConfigerAdapter) Float(ctx context.Context, key string) (float64, error) { + return o.delegate.Float(key) +} + +func (o *oldToNewConfigerAdapter) DefaultString(ctx context.Context, key string, defaultVal string) string { + return o.delegate.DefaultString(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + return o.delegate.DefaultStrings(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultInt(ctx context.Context, key string, defaultVal int) int { + return o.delegate.DefaultInt(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + return o.delegate.DefaultInt64(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + return o.delegate.DefaultBool(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + return o.delegate.DefaultFloat(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DIY(ctx context.Context, key string) (interface{}, error) { + return o.delegate.DIY(key) +} + +func (o *oldToNewConfigerAdapter) GetSection(ctx context.Context, section string) (map[string]string, error) { + return o.delegate.GetSection(section) +} + +func (o *oldToNewConfigerAdapter) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { + return errors.New("unsupported operation, please use actual config.Configer") +} + +func (o *oldToNewConfigerAdapter) Sub(ctx context.Context, key string) (config.Configer, error) { + return nil, errors.New("unsupported operation, please use actual config.Configer") +} + +func (o *oldToNewConfigerAdapter) OnChange(ctx context.Context, key string, fn func(value string)) { + // do nothing +} + +func (o *oldToNewConfigerAdapter) SaveConfigFile(ctx context.Context, filename string) error { + return o.delegate.SaveConfigFile(filename) +} + +type oldToNewConfigAdapter struct { + delegate Config +} + +func (o *oldToNewConfigAdapter) Parse(key string) (config.Configer, error) { + old, err := o.delegate.Parse(key) + if err != nil { + return nil, err + } + return &oldToNewConfigerAdapter{delegate: old}, nil +} + +func (o *oldToNewConfigAdapter) ParseData(data []byte) (config.Configer, error) { + old, err := o.delegate.ParseData(data) + if err != nil { + return nil, err + } + return &oldToNewConfigerAdapter{delegate: old}, nil +} diff --git a/pkg/adapter/config/config.go b/pkg/adapter/config/config.go new file mode 100644 index 00000000..c870a15a --- /dev/null +++ b/pkg/adapter/config/config.go @@ -0,0 +1,151 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package config is used to parse config. +// Usage: +// import "github.com/astaxie/beego/config" +// Examples. +// +// cnf, err := config.NewConfig("ini", "config.conf") +// +// cnf APIS: +// +// cnf.Set(key, val string) error +// cnf.String(key string) string +// cnf.Strings(key string) []string +// cnf.Int(key string) (int, error) +// cnf.Int64(key string) (int64, error) +// cnf.Bool(key string) (bool, error) +// cnf.Float(key string) (float64, error) +// cnf.DefaultString(key string, defaultVal string) string +// cnf.DefaultStrings(key string, defaultVal []string) []string +// cnf.DefaultInt(key string, defaultVal int) int +// cnf.DefaultInt64(key string, defaultVal int64) int64 +// cnf.DefaultBool(key string, defaultVal bool) bool +// cnf.DefaultFloat(key string, defaultVal float64) float64 +// cnf.DIY(key string) (interface{}, error) +// cnf.GetSection(section string) (map[string]string, error) +// cnf.SaveConfigFile(filename string) error +// More docs http://beego.me/docs/module/config.md +package config + +import ( + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +// Configer defines how to get and set value from configuration raw data. +type Configer interface { + Set(key, val string) error // support section::key type in given key when using ini type. + String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + Strings(key string) []string // get string slice + Int(key string) (int, error) + Int64(key string) (int64, error) + Bool(key string) (bool, error) + Float(key string) (float64, error) + DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + DefaultStrings(key string, defaultVal []string) []string // get string slice + DefaultInt(key string, defaultVal int) int + DefaultInt64(key string, defaultVal int64) int64 + DefaultBool(key string, defaultVal bool) bool + DefaultFloat(key string, defaultVal float64) float64 + DIY(key string) (interface{}, error) + GetSection(section string) (map[string]string, error) + SaveConfigFile(filename string) error +} + +// Config is the adapter interface for parsing config file to get raw data to Configer. +type Config interface { + Parse(key string) (Configer, error) + ParseData(data []byte) (Configer, error) +} + +var adapters = make(map[string]Config) + +// Register makes a config adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Config) { + config.Register(name, &oldToNewConfigAdapter{delegate: adapter}) +} + +// NewConfig adapterName is ini/json/xml/yaml. +// filename is the config file path. +func NewConfig(adapterName, filename string) (Configer, error) { + cfg, err := config.NewConfig(adapterName, filename) + if err != nil { + return nil, err + } + + // it was registered by using Register method + res, ok := cfg.(*oldToNewConfigerAdapter) + if ok { + return res.delegate, nil + } + + return &newToOldConfigerAdapter{ + delegate: cfg, + }, nil +} + +// NewConfigData adapterName is ini/json/xml/yaml. +// data is the config data. +func NewConfigData(adapterName string, data []byte) (Configer, error) { + cfg, err := config.NewConfigData(adapterName, data) + if err != nil { + return nil, err + } + + // it was registered by using Register method + res, ok := cfg.(*oldToNewConfigerAdapter) + if ok { + return res.delegate, nil + } + + return &newToOldConfigerAdapter{ + delegate: cfg, + }, nil +} + +// ExpandValueEnvForMap convert all string value with environment variable. +func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { + return config.ExpandValueEnvForMap(m) +} + +// ExpandValueEnv returns value of convert with environment variable. +// +// Return environment variable if value start with "${" and end with "}". +// Return default value if environment variable is empty or not exist. +// +// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue". +// Examples: +// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable. +// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/". +// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie". +func ExpandValueEnv(value string) string { + return config.ExpandValueEnv(value) +} + +// ParseBool returns the boolean value represented by the string. +// +// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On, +// 0, 0.0, f, F, FALSE, false, False, NO, no, No, N,n, OFF, off, Off. +// Any other value returns an error. +func ParseBool(val interface{}) (value bool, err error) { + return config.ParseBool(val) +} + +// ToString converts values of any type to string. +func ToString(x interface{}) string { + return config.ToString(x) +} diff --git a/pkg/adapter/config/config_test.go b/pkg/adapter/config/config_test.go new file mode 100644 index 00000000..15d6ffa6 --- /dev/null +++ b/pkg/adapter/config/config_test.go @@ -0,0 +1,55 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "os" + "testing" +) + +func TestExpandValueEnv(t *testing.T) { + + testCases := []struct { + item string + want string + }{ + {"", ""}, + {"$", "$"}, + {"{", "{"}, + {"{}", "{}"}, + {"${}", ""}, + {"${|}", ""}, + {"${}", ""}, + {"${{}}", ""}, + {"${{||}}", "}"}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}}", "}"}, + {"${pwd||{{||}}}", "{{||}}"}, + {"${GOPATH}", os.Getenv("GOPATH")}, + {"${GOPATH||}", os.Getenv("GOPATH")}, + {"${GOPATH||root}", os.Getenv("GOPATH")}, + {"${GOPATH_NOT||root}", "root"}, + {"${GOPATH_NOT||||root}", "||root"}, + } + + for _, c := range testCases { + if got := ExpandValueEnv(c.item); got != c.want { + t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) + } + } + +} diff --git a/pkg/adapter/config/env/env.go b/pkg/adapter/config/env/env.go new file mode 100644 index 00000000..77d7b53c --- /dev/null +++ b/pkg/adapter/config/env/env.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package env is used to parse environment. +package env + +import ( + "github.com/astaxie/beego/pkg/infrastructure/config/env" +) + +// Get returns a value by key. +// If the key does not exist, the default value will be returned. +func Get(key string, defVal string) string { + return env.Get(key, defVal) +} + +// MustGet returns a value by key. +// If the key does not exist, it will return an error. +func MustGet(key string) (string, error) { + return env.MustGet(key) +} + +// Set sets a value in the ENV copy. +// This does not affect the child process environment. +func Set(key string, value string) { + env.Set(key, value) +} + +// MustSet sets a value in the ENV copy and the child process environment. +// It returns an error in case the set operation failed. +func MustSet(key string, value string) error { + return env.MustSet(key, value) +} + +// GetAll returns all keys/values in the current child process environment. +func GetAll() map[string]string { + return env.GetAll() +} diff --git a/pkg/adapter/config/env/env_test.go b/pkg/adapter/config/env/env_test.go new file mode 100644 index 00000000..3f1d4dba --- /dev/null +++ b/pkg/adapter/config/env/env_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package env + +import ( + "os" + "testing" +) + +func TestEnvGet(t *testing.T) { + gopath := Get("GOPATH", "") + if gopath != os.Getenv("GOPATH") { + t.Error("expected GOPATH not empty.") + } + + noExistVar := Get("NOEXISTVAR", "foo") + if noExistVar != "foo" { + t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar) + } +} + +func TestEnvMustGet(t *testing.T) { + gopath, err := MustGet("GOPATH") + if err != nil { + t.Error(err) + } + + if gopath != os.Getenv("GOPATH") { + t.Errorf("expected GOPATH to be the same, got %s.", gopath) + } + + _, err = MustGet("NOEXISTVAR") + if err == nil { + t.Error("expected error to be non-nil") + } +} + +func TestEnvSet(t *testing.T) { + Set("MYVAR", "foo") + myVar := Get("MYVAR", "bar") + if myVar != "foo" { + t.Errorf("expected MYVAR to equal foo, got %s.", myVar) + } +} + +func TestEnvMustSet(t *testing.T) { + err := MustSet("FOO", "bar") + if err != nil { + t.Error(err) + } + + fooVar := os.Getenv("FOO") + if fooVar != "bar" { + t.Errorf("expected FOO variable to equal bar, got %s.", fooVar) + } +} + +func TestEnvGetAll(t *testing.T) { + envMap := GetAll() + if len(envMap) == 0 { + t.Error("expected environment not empty.") + } +} diff --git a/pkg/adapter/config/fake.go b/pkg/adapter/config/fake.go new file mode 100644 index 00000000..fac96b41 --- /dev/null +++ b/pkg/adapter/config/fake.go @@ -0,0 +1,25 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +// NewFakeConfig return a fake Configer +func NewFakeConfig() Configer { + new := config.NewFakeConfig() + return &newToOldConfigerAdapter{delegate: new} +} diff --git a/pkg/adapter/config/ini_test.go b/pkg/adapter/config/ini_test.go new file mode 100644 index 00000000..ffcdb294 --- /dev/null +++ b/pkg/adapter/config/ini_test.go @@ -0,0 +1,190 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +func TestIni(t *testing.T) { + + var ( + inicontext = ` +;comment one +#comment two +appname = beeapi +httpport = 8080 +mysqlport = 3600 +PI = 3.1415976 +runmode = "dev" +autorender = false +copyrequestbody = true +session= on +cookieon= off +newreg = OFF +needlogin = ON +enableSession = Y +enableCookie = N +flag = 1 +path1 = ${GOPATH} +path2 = ${GOPATH||/home/go} +[demo] +key1="asta" +key2 = "xie" +CaseInsensitive = true +peers = one;two;three +password = ${GOPATH} +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "pi": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "demo::key1": "asta", + "demo::key2": "xie", + "demo::CaseInsensitive": true, + "demo::peers": []string{"one", "two", "three"}, + "demo::password": os.Getenv("GOPATH"), + "null": "", + "demo2::key1": "", + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testini.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(inicontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testini.conf") + iniconf, err := NewConfig("ini", "testini.conf") + if err != nil { + t.Fatal(err) + } + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = iniconf.Int(k) + case int64: + value, err = iniconf.Int64(k) + case float64: + value, err = iniconf.Float(k) + case bool: + value, err = iniconf.Bool(k) + case []string: + value = iniconf.Strings(k) + case string: + value = iniconf.String(k) + default: + value, err = iniconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fail,err %s", k, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = iniconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if iniconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} + +func TestIniSave(t *testing.T) { + + const ( + inicontext = ` +app = app +;comment one +#comment two +# comment three +appname = beeapi +httpport = 8080 +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name = mysql +` + + saveResult = ` +app=app +#comment one +#comment two +# comment three +appname=beeapi +httpport=8080 + +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name=mysql +` + ) + cfg, err := NewConfigData("ini", []byte(inicontext)) + if err != nil { + t.Fatal(err) + } + name := "newIniConfig.ini" + if err := cfg.SaveConfigFile(name); err != nil { + t.Fatal(err) + } + defer os.Remove(name) + + if data, err := ioutil.ReadFile(name); err != nil { + t.Fatal(err) + } else { + cfgData := string(data) + datas := strings.Split(saveResult, "\n") + for _, line := range datas { + if !strings.Contains(cfgData, line+"\n") { + t.Fatalf("different after save ini config file. need contains %q", line) + } + } + + } +} diff --git a/pkg/adapter/config/json.go b/pkg/adapter/config/json.go new file mode 100644 index 00000000..d0fe4d09 --- /dev/null +++ b/pkg/adapter/config/json.go @@ -0,0 +1,19 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + _ "github.com/astaxie/beego/pkg/infrastructure/config/json" +) diff --git a/pkg/adapter/config/json_test.go b/pkg/adapter/config/json_test.go new file mode 100644 index 00000000..16f42409 --- /dev/null +++ b/pkg/adapter/config/json_test.go @@ -0,0 +1,222 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "os" + "testing" +) + +func TestJsonStartsWithArray(t *testing.T) { + + const jsoncontextwitharray = `[ + { + "url": "user", + "serviceAPI": "http://www.test.com/user" + }, + { + "url": "employee", + "serviceAPI": "http://www.test.com/employee" + } +]` + f, err := os.Create("testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontextwitharray) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjsonWithArray.conf") + jsonconf, err := NewConfig("json", "testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + rootArray, err := jsonconf.DIY("rootArray") + if err != nil { + t.Error("array does not exist as element") + } + rootArrayCasted := rootArray.([]interface{}) + if rootArrayCasted == nil { + t.Error("array from root is nil") + } else { + elem := rootArrayCasted[0].(map[string]interface{}) + if elem["url"] != "user" || elem["serviceAPI"] != "http://www.test.com/user" { + t.Error("array[0] values are not valid") + } + + elem2 := rootArrayCasted[1].(map[string]interface{}) + if elem2["url"] != "employee" || elem2["serviceAPI"] != "http://www.test.com/employee" { + t.Error("array[1] values are not valid") + } + } +} + +func TestJson(t *testing.T) { + + var ( + jsoncontext = `{ +"appname": "beeapi", +"testnames": "foo;bar", +"httpport": 8080, +"mysqlport": 3600, +"PI": 3.1415976, +"runmode": "dev", +"autorender": false, +"copyrequestbody": true, +"session": "on", +"cookieon": "off", +"newreg": "OFF", +"needlogin": "ON", +"enableSession": "Y", +"enableCookie": "N", +"flag": 1, +"path1": "${GOPATH}", +"path2": "${GOPATH||/home/go}", +"database": { + "host": "host", + "port": "port", + "database": "database", + "username": "username", + "password": "${GOPATH}", + "conns":{ + "maxconnection":12, + "autoconnect":true, + "connectioninfo":"info", + "root": "${GOPATH}" + } + } +}` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "testnames": []string{"foo", "bar"}, + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "database::host": "host", + "database::port": "port", + "database::database": "database", + "database::password": os.Getenv("GOPATH"), + "database::conns::maxconnection": 12, + "database::conns::autoconnect": true, + "database::conns::connectioninfo": "info", + "database::conns::root": os.Getenv("GOPATH"), + "unknown": "", + } + ) + + f, err := os.Create("testjson.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjson.conf") + jsonconf, err := NewConfig("json", "testjson.conf") + if err != nil { + t.Fatal(err) + } + + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = jsonconf.Int(k) + case int64: + value, err = jsonconf.Int64(k) + case float64: + value, err = jsonconf.Float(k) + case bool: + value, err = jsonconf.Bool(k) + case []string: + value = jsonconf.Strings(k) + case string: + value = jsonconf.String(k) + default: + value, err = jsonconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = jsonconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if jsonconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + + if db, err := jsonconf.DIY("database"); err != nil { + t.Fatal(err) + } else if m, ok := db.(map[string]interface{}); !ok { + t.Log(db) + t.Fatal("db not map[string]interface{}") + } else { + if m["host"].(string) != "host" { + t.Fatal("get host err") + } + } + + if _, err := jsonconf.Int("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int") + } + + if _, err := jsonconf.Int64("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int64") + } + + if _, err := jsonconf.Float("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Float") + } + + if _, err := jsonconf.DIY("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an interface{}") + } + + if val := jsonconf.String("unknown"); val != "" { + t.Error("unknown keys should return an empty string when expecting a String") + } + + if _, err := jsonconf.Bool("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Bool") + } + + if !jsonconf.DefaultBool("unknown", true) { + t.Error("unknown keys with default value wrong") + } +} diff --git a/pkg/adapter/config/xml/xml.go b/pkg/adapter/config/xml/xml.go new file mode 100644 index 00000000..f96cdcd6 --- /dev/null +++ b/pkg/adapter/config/xml/xml.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package xml for config provider. +// +// depend on github.com/beego/x2j. +// +// go install github.com/beego/x2j. +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/xml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("xml", "config.xml") +// +// More docs http://beego.me/docs/module/config.md +package xml + +import ( + _ "github.com/astaxie/beego/pkg/infrastructure/config/xml" +) diff --git a/pkg/adapter/config/xml/xml_test.go b/pkg/adapter/config/xml/xml_test.go new file mode 100644 index 00000000..122c5027 --- /dev/null +++ b/pkg/adapter/config/xml/xml_test.go @@ -0,0 +1,125 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/pkg/adapter/config" +) + +func TestXML(t *testing.T) { + + var ( + //xml parse should incluce in tags + xmlcontext = ` + +beeapi +8080 +3600 +3.1415976 +dev +false +true +${GOPATH} +${GOPATH||/home/go} + +1 +MySection + + +` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testxml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(xmlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testxml.conf") + + xmlconf, err := config.NewConfig("xml", "testxml.conf") + if err != nil { + t.Fatal(err) + } + + var xmlsection map[string]string + xmlsection, err = xmlconf.GetSection("mysection") + if err != nil { + t.Fatal(err) + } + + if len(xmlsection) == 0 { + t.Error("section should not be empty") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = xmlconf.Int(k) + case int64: + value, err = xmlconf.Int64(k) + case float64: + value, err = xmlconf.Float(k) + case bool: + value, err = xmlconf.Bool(k) + case []string: + value = xmlconf.Strings(k) + case string: + value = xmlconf.String(k) + default: + value, err = xmlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = xmlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if xmlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } +} diff --git a/pkg/adapter/config/yaml/yaml.go b/pkg/adapter/config/yaml/yaml.go new file mode 100644 index 00000000..bc2398e9 --- /dev/null +++ b/pkg/adapter/config/yaml/yaml.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package yaml for config provider +// +// depend on github.com/beego/goyaml2 +// +// go install github.com/beego/goyaml2 +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/yaml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("yaml", "config.yaml") +// +// More docs http://beego.me/docs/module/config.md +package yaml + +import ( + _ "github.com/astaxie/beego/pkg/infrastructure/config/yaml" +) diff --git a/pkg/adapter/config/yaml/yaml_test.go b/pkg/adapter/config/yaml/yaml_test.go new file mode 100644 index 00000000..e4e309a2 --- /dev/null +++ b/pkg/adapter/config/yaml/yaml_test.go @@ -0,0 +1,115 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yaml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/pkg/adapter/config" +) + +func TestYaml(t *testing.T) { + + var ( + yamlcontext = ` +"appname": beeapi +"httpport": 8080 +"mysqlport": 3600 +"PI": 3.1415976 +"runmode": dev +"autorender": false +"copyrequestbody": true +"PATH": GOPATH +"path1": ${GOPATH} +"path2": ${GOPATH||/home/go} +"empty": "" +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "PATH": "GOPATH", + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + f, err := os.Create("testyaml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(yamlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testyaml.conf") + yamlconf, err := config.NewConfig("yaml", "testyaml.conf") + if err != nil { + t.Fatal(err) + } + + if yamlconf.String("appname") != "beeapi" { + t.Fatal("appname not equal to beeapi") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = yamlconf.Int(k) + case int64: + value, err = yamlconf.Int64(k) + case float64: + value, err = yamlconf.Float(k) + case bool: + value, err = yamlconf.Bool(k) + case []string: + value = yamlconf.Strings(k) + case string: + value = yamlconf.String(k) + default: + value, err = yamlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = yamlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if yamlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} From 78d91062c911d53ec9bc5ce278ce5dee14020e15 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 1 Sep 2020 22:16:49 +0800 Subject: [PATCH 16/35] Adapt new API to old API: httplib --- pkg/adapter/httplib/httplib.go | 300 ++++++++++++++++++++++++++++ pkg/adapter/httplib/httplib_test.go | 286 ++++++++++++++++++++++++++ 2 files changed, 586 insertions(+) create mode 100644 pkg/adapter/httplib/httplib.go create mode 100644 pkg/adapter/httplib/httplib_test.go diff --git a/pkg/adapter/httplib/httplib.go b/pkg/adapter/httplib/httplib.go new file mode 100644 index 00000000..d2ef36c1 --- /dev/null +++ b/pkg/adapter/httplib/httplib.go @@ -0,0 +1,300 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package httplib is used as http.Client +// Usage: +// +// import "github.com/astaxie/beego/httplib" +// +// b := httplib.Post("http://beego.me/") +// b.Param("username","astaxie") +// b.Param("password","123456") +// b.PostFile("uploadfile1", "httplib.pdf") +// b.PostFile("uploadfile2", "httplib.txt") +// str, err := b.String() +// if err != nil { +// t.Fatal(err) +// } +// fmt.Println(str) +// +// more docs http://beego.me/docs/module/httplib.md +package httplib + +import ( + "crypto/tls" + "net" + "net/http" + "net/url" + "time" + + "github.com/astaxie/beego/pkg/client/httplib" +) + +// SetDefaultSetting Overwrite default settings +func SetDefaultSetting(setting BeegoHTTPSettings) { + httplib.SetDefaultSetting(httplib.BeegoHTTPSettings(setting)) +} + +// NewBeegoRequest return *BeegoHttpRequest with specific method +func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { + return &BeegoHTTPRequest{ + delegate: httplib.NewBeegoRequest(rawurl, method), + } +} + +// Get returns *BeegoHttpRequest with GET method. +func Get(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "GET") +} + +// Post returns *BeegoHttpRequest with POST method. +func Post(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "POST") +} + +// Put returns *BeegoHttpRequest with PUT method. +func Put(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "PUT") +} + +// Delete returns *BeegoHttpRequest DELETE method. +func Delete(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "DELETE") +} + +// Head returns *BeegoHttpRequest with HEAD method. +func Head(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "HEAD") +} + +// BeegoHTTPSettings is the http.Client setting +type BeegoHTTPSettings httplib.BeegoHTTPSettings + +// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. +type BeegoHTTPRequest struct { + delegate *httplib.BeegoHTTPRequest +} + +// GetRequest return the request object +func (b *BeegoHTTPRequest) GetRequest() *http.Request { + return b.delegate.GetRequest() +} + +// Setting Change request settings +func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { + b.delegate.Setting(httplib.BeegoHTTPSettings(setting)) + return b +} + +// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. +func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest { + b.delegate.SetBasicAuth(username, password) + return b +} + +// SetEnableCookie sets enable/disable cookiejar +func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest { + b.delegate.SetEnableCookie(enable) + return b +} + +// SetUserAgent sets User-Agent header field +func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { + b.delegate.SetUserAgent(useragent) + return b +} + +// Debug sets show debug or not when executing request. +func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { + b.delegate.Debug(isdebug) + return b +} + +// Retries sets Retries times. +// default is 0 means no retried. +// -1 means retried forever. +// others means retried times. +func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { + b.delegate.Retries(times) + return b +} + +func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { + b.delegate.RetryDelay(delay) + return b +} + +// DumpBody setting whether need to Dump the Body. +func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { + b.delegate.DumpBody(isdump) + return b +} + +// DumpRequest return the DumpRequest +func (b *BeegoHTTPRequest) DumpRequest() []byte { + return b.delegate.DumpRequest() +} + +// SetTimeout sets connect time out and read-write time out for BeegoRequest. +func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { + b.delegate.SetTimeout(connectTimeout, readWriteTimeout) + return b +} + +// SetTLSClientConfig sets tls connection configurations if visiting https url. +func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { + b.delegate.SetTLSClientConfig(config) + return b +} + +// Header add header item string in request. +func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { + b.delegate.Header(key, value) + return b +} + +// SetHost set the request host +func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { + b.delegate.SetHost(host) + return b +} + +// SetProtocolVersion Set the protocol version for incoming requests. +// Client requests always use HTTP/1.1. +func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { + b.delegate.SetProtocolVersion(vers) + return b +} + +// SetCookie add cookie into request. +func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { + b.delegate.SetCookie(cookie) + return b +} + +// SetTransport set the setting transport +func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { + b.delegate.SetTransport(transport) + return b +} + +// SetProxy set the http proxy +// example: +// +// func(req *http.Request) (*url.URL, error) { +// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") +// return u, nil +// } +func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest { + b.delegate.SetProxy(proxy) + return b +} + +// SetCheckRedirect specifies the policy for handling redirects. +// +// If CheckRedirect is nil, the Client uses its default policy, +// which is to stop after 10 consecutive requests. +func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest { + b.delegate.SetCheckRedirect(redirect) + return b +} + +// Param adds query param in to request. +// params build query string as ?key1=value1&key2=value2... +func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { + b.delegate.Param(key, value) + return b +} + +// PostFile add a post file to the request +func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { + b.delegate.PostFile(formname, filename) + return b +} + +// Body adds request raw body. +// it supports string and []byte. +func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { + b.delegate.Body(data) + return b +} + +// XMLBody adds request raw body encoding by XML. +func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + _, err := b.delegate.XMLBody(obj) + return b, err +} + +// YAMLBody adds request raw body encoding by YAML. +func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + _, err := b.delegate.YAMLBody(obj) + return b, err +} + +// JSONBody adds request raw body encoding by JSON. +func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { + _, err := b.delegate.JSONBody(obj) + return b, err +} + +// DoRequest will do the client.Do +func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { + return b.delegate.DoRequest() +} + +// String returns the body string in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) String() (string, error) { + return b.delegate.String() +} + +// Bytes returns the body []byte in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { + return b.delegate.Bytes() +} + +// ToFile saves the body data in response to one file. +// it calls Response inner. +func (b *BeegoHTTPRequest) ToFile(filename string) error { + return b.delegate.ToFile(filename) +} + +// ToJSON returns the map that marshals from the body bytes as json in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { + return b.delegate.ToJSON(v) +} + +// ToXML returns the map that marshals from the body bytes as xml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToXML(v interface{}) error { + return b.delegate.ToXML(v) +} + +// ToYAML returns the map that marshals from the body bytes as yaml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { + return b.delegate.ToYAML(v) +} + +// Response executes request client gets response mannually. +func (b *BeegoHTTPRequest) Response() (*http.Response, error) { + return b.delegate.Response() +} + +// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. +func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { + return httplib.TimeoutDialer(cTimeout, rwTimeout) +} diff --git a/pkg/adapter/httplib/httplib_test.go b/pkg/adapter/httplib/httplib_test.go new file mode 100644 index 00000000..e7605c87 --- /dev/null +++ b/pkg/adapter/httplib/httplib_test.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "errors" + "io/ioutil" + "net" + "net/http" + "os" + "strings" + "testing" + "time" +) + +func TestResponse(t *testing.T) { + req := Get("http://httpbin.org/get") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) +} + +func TestDoRequest(t *testing.T) { + req := Get("https://goolnk.com/33BD2j") + retryAmount := 1 + req.Retries(1) + req.RetryDelay(1400 * time.Millisecond) + retryDelay := 1400 * time.Millisecond + + req.SetCheckRedirect(func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + }) + + startTime := time.Now().UnixNano() / int64(time.Millisecond) + + _, err := req.Response() + if err == nil { + t.Fatal("Response should have yielded an error") + } + + endTime := time.Now().UnixNano() / int64(time.Millisecond) + elapsedTime := endTime - startTime + delayedTime := int64(retryAmount) * retryDelay.Milliseconds() + + if elapsedTime < delayedTime { + t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) + } + +} + +func TestGet(t *testing.T) { + req := Get("http://httpbin.org/get") + b, err := req.Bytes() + if err != nil { + t.Fatal(err) + } + t.Log(b) + + s, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(s) + + if string(b) != s { + t.Fatal("request data not match") + } +} + +func TestSimplePost(t *testing.T) { + v := "smallfish" + req := Post("http://httpbin.org/post") + req.Param("username", v) + + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in post") + } +} + +// func TestPostFile(t *testing.T) { +// v := "smallfish" +// req := Post("http://httpbin.org/post") +// req.Debug(true) +// req.Param("username", v) +// req.PostFile("uploadfile", "httplib_test.go") + +// str, err := req.String() +// if err != nil { +// t.Fatal(err) +// } +// t.Log(str) + +// n := strings.Index(str, v) +// if n == -1 { +// t.Fatal(v + " not found in post") +// } +// } + +func TestSimplePut(t *testing.T) { + str, err := Put("http://httpbin.org/put").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDelete(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDeleteParam(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").Param("key", "val").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestWithCookie(t *testing.T) { + v := "smallfish" + str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + str, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in cookie") + } +} + +func TestWithBasicAuth(t *testing.T) { + str, err := Get("http://httpbin.org/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + n := strings.Index(str, "authenticated") + if n == -1 { + t.Fatal("authenticated not found in response") + } +} + +func TestWithUserAgent(t *testing.T) { + v := "beego" + str, err := Get("http://httpbin.org/headers").SetUserAgent(v).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestWithSetting(t *testing.T) { + v := "beego" + var setting BeegoHTTPSettings + setting.EnableCookie = true + setting.UserAgent = v + setting.Transport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 50, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + setting.ReadWriteTimeout = 5 * time.Second + SetDefaultSetting(setting) + + str, err := Get("http://httpbin.org/get").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestToJson(t *testing.T) { + req := Get("http://httpbin.org/ip") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) + + // httpbin will return http remote addr + type IP struct { + Origin string `json:"origin"` + } + var ip IP + err = req.ToJSON(&ip) + if err != nil { + t.Fatal(err) + } + t.Log(ip.Origin) + ips := strings.Split(ip.Origin, ",") + if len(ips) == 0 { + t.Fatal("response is not valid ip") + } + for i := range ips { + if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil { + t.Fatal("response is not valid ip") + } + } + +} + +func TestToFile(t *testing.T) { + f := "beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.Remove(f) + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestToFileDir(t *testing.T) { + f := "./files/beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll("./files") + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestHeader(t *testing.T) { + req := Get("http://httpbin.org/headers") + req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36") + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} From 3bf5cde38c840383b75ab7873fdb062aa2abe7ad Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 2 Sep 2020 20:36:53 +0800 Subject: [PATCH 17/35] adapt context --- pkg/adapter/context/acceptencoder.go | 45 +++++ pkg/adapter/context/context.go | 146 ++++++++++++++ pkg/adapter/context/input.go | 282 +++++++++++++++++++++++++++ pkg/adapter/context/output.go | 154 +++++++++++++++ pkg/adapter/context/renderer.go | 9 + pkg/adapter/context/response.go | 26 +++ 6 files changed, 662 insertions(+) create mode 100644 pkg/adapter/context/acceptencoder.go create mode 100644 pkg/adapter/context/context.go create mode 100644 pkg/adapter/context/input.go create mode 100644 pkg/adapter/context/output.go create mode 100644 pkg/adapter/context/renderer.go create mode 100644 pkg/adapter/context/response.go diff --git a/pkg/adapter/context/acceptencoder.go b/pkg/adapter/context/acceptencoder.go new file mode 100644 index 00000000..e578de45 --- /dev/null +++ b/pkg/adapter/context/acceptencoder.go @@ -0,0 +1,45 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "io" + "net/http" + "os" + + "github.com/astaxie/beego/pkg/server/web/context" +) + +// InitGzip init the gzipcompress +func InitGzip(minLength, compressLevel int, methods []string) { + context.InitGzip(minLength, compressLevel, methods) +} + +// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) +func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { + return context.WriteFile(encoding, writer, file) +} + +// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) +func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { + return context.WriteBody(encoding, writer, content) +} + +// ParseEncoding will extract the right encoding for response +// the Accept-Encoding's sec is here: +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3 +func ParseEncoding(r *http.Request) string { + return context.ParseEncoding(r) +} diff --git a/pkg/adapter/context/context.go b/pkg/adapter/context/context.go new file mode 100644 index 00000000..f9d8c624 --- /dev/null +++ b/pkg/adapter/context/context.go @@ -0,0 +1,146 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package context provide the context utils +// Usage: +// +// import "github.com/astaxie/beego/context" +// +// ctx := context.Context{Request:req,ResponseWriter:rw} +// +// more docs http://beego.me/docs/module/context.md +package context + +import ( + "bufio" + "net" + "net/http" + + "github.com/astaxie/beego/pkg/server/web/context" +) + +// commonly used mime-types +const ( + ApplicationJSON = context.ApplicationJSON + ApplicationXML = context.ApplicationXML + ApplicationYAML = context.ApplicationYAML + TextXML = context.TextXML +) + +// NewContext return the Context with Input and Output +func NewContext() *Context { + return (*Context)(context.NewContext()) +} + +// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. +// BeegoInput and BeegoOutput provides some api to operate request and response more easily. +type Context context.Context + +// Reset init Context, BeegoInput and BeegoOutput +func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { + (*context.Context)(ctx).Reset(rw, r) +} + +// Redirect does redirection to localurl with http header status code. +func (ctx *Context) Redirect(status int, localurl string) { + (*context.Context)(ctx).Redirect(status, localurl) +} + +// Abort stops this request. +// if beego.ErrorMaps exists, panic body. +func (ctx *Context) Abort(status int, body string) { + (*context.Context)(ctx).Abort(status, body) +} + +// WriteString Write string to response body. +// it sends response body. +func (ctx *Context) WriteString(content string) { + (*context.Context)(ctx).WriteString(content) +} + +// GetCookie Get cookie from request by a given key. +// It's alias of BeegoInput.Cookie. +func (ctx *Context) GetCookie(key string) string { + return (*context.Context)(ctx).GetCookie(key) +} + +// SetCookie Set cookie for response. +// It's alias of BeegoOutput.Cookie. +func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { + (*context.Context)(ctx).SetCookie(name, value, others) +} + +// GetSecureCookie Get secure cookie from request by a given key. +func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { + return (*context.Context)(ctx).GetSecureCookie(Secret, key) +} + +// SetSecureCookie Set Secure cookie for response. +func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { + (*context.Context)(ctx).SetSecureCookie(Secret, name, value, others) +} + +// XSRFToken creates a xsrf token string and returns. +func (ctx *Context) XSRFToken(key string, expire int64) string { + return (*context.Context)(ctx).XSRFToken(key, expire) +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (ctx *Context) CheckXSRFCookie() bool { + return (*context.Context)(ctx).CheckXSRFCookie() +} + +// RenderMethodResult renders the return value of a controller method to the output +func (ctx *Context) RenderMethodResult(result interface{}) { + (*context.Context)(ctx).RenderMethodResult(result) +} + +// Response is a wrapper for the http.ResponseWriter +// started set to true if response was written to then don't execute other handler +type Response context.Response + +// Write writes the data to the connection as part of an HTTP reply, +// and sets `started` to true. +// started means the response has sent out. +func (r *Response) Write(p []byte) (int, error) { + return (*context.Response)(r).Write(p) +} + +// WriteHeader sends an HTTP response header with status code, +// and sets `started` to true. +func (r *Response) WriteHeader(code int) { + (*context.Response)(r).WriteHeader(code) +} + +// Hijack hijacker for http +func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return (*context.Response)(r).Hijack() +} + +// Flush http.Flusher +func (r *Response) Flush() { + (*context.Response)(r).Flush() +} + +// CloseNotify http.CloseNotifier +func (r *Response) CloseNotify() <-chan bool { + return (*context.Response)(r).CloseNotify() +} + +// Pusher http.Pusher +func (r *Response) Pusher() (pusher http.Pusher) { + return (*context.Response)(r).Pusher() +} diff --git a/pkg/adapter/context/input.go b/pkg/adapter/context/input.go new file mode 100644 index 00000000..a1d08855 --- /dev/null +++ b/pkg/adapter/context/input.go @@ -0,0 +1,282 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "github.com/astaxie/beego/pkg/server/web/context" +) + +// BeegoInput operates the http request header, data, cookie and body. +// it also contains router params and current session. +type BeegoInput context.BeegoInput + +// NewInput return BeegoInput generated by Context. +func NewInput() *BeegoInput { + return (*BeegoInput)(context.NewInput()) +} + +// Reset init the BeegoInput +func (input *BeegoInput) Reset(ctx *Context) { + (*context.BeegoInput)(input).Reset((*context.Context)(ctx)) +} + +// Protocol returns request protocol name, such as HTTP/1.1 . +func (input *BeegoInput) Protocol() string { + return (*context.BeegoInput)(input).Protocol() +} + +// URI returns full request url with query string, fragment. +func (input *BeegoInput) URI() string { + return input.Context.Request.RequestURI +} + +// URL returns request url path (without query string, fragment). +func (input *BeegoInput) URL() string { + return (*context.BeegoInput)(input).URL() +} + +// Site returns base site url as scheme://domain type. +func (input *BeegoInput) Site() string { + return (*context.BeegoInput)(input).Site() +} + +// Scheme returns request scheme as "http" or "https". +func (input *BeegoInput) Scheme() string { + return (*context.BeegoInput)(input).Scheme() +} + +// Domain returns host name. +// Alias of Host method. +func (input *BeegoInput) Domain() string { + return (*context.BeegoInput)(input).Domain() +} + +// Host returns host name. +// if no host info in request, return localhost. +func (input *BeegoInput) Host() string { + return (*context.BeegoInput)(input).Host() +} + +// Method returns http request method. +func (input *BeegoInput) Method() string { + return (*context.BeegoInput)(input).Method() +} + +// Is returns boolean of this request is on given method, such as Is("POST"). +func (input *BeegoInput) Is(method string) bool { + return (*context.BeegoInput)(input).Is(method) +} + +// IsGet Is this a GET method request? +func (input *BeegoInput) IsGet() bool { + return (*context.BeegoInput)(input).IsGet() +} + +// IsPost Is this a POST method request? +func (input *BeegoInput) IsPost() bool { + return (*context.BeegoInput)(input).IsPost() +} + +// IsHead Is this a Head method request? +func (input *BeegoInput) IsHead() bool { + return (*context.BeegoInput)(input).IsHead() +} + +// IsOptions Is this a OPTIONS method request? +func (input *BeegoInput) IsOptions() bool { + return (*context.BeegoInput)(input).IsOptions() +} + +// IsPut Is this a PUT method request? +func (input *BeegoInput) IsPut() bool { + return (*context.BeegoInput)(input).IsPut() +} + +// IsDelete Is this a DELETE method request? +func (input *BeegoInput) IsDelete() bool { + return (*context.BeegoInput)(input).IsDelete() +} + +// IsPatch Is this a PATCH method request? +func (input *BeegoInput) IsPatch() bool { + return (*context.BeegoInput)(input).IsPatch() +} + +// IsAjax returns boolean of this request is generated by ajax. +func (input *BeegoInput) IsAjax() bool { + return (*context.BeegoInput)(input).IsAjax() +} + +// IsSecure returns boolean of this request is in https. +func (input *BeegoInput) IsSecure() bool { + return (*context.BeegoInput)(input).IsSecure() +} + +// IsWebsocket returns boolean of this request is in webSocket. +func (input *BeegoInput) IsWebsocket() bool { + return (*context.BeegoInput)(input).IsWebsocket() +} + +// IsUpload returns boolean of whether file uploads in this request or not.. +func (input *BeegoInput) IsUpload() bool { + return (*context.BeegoInput)(input).IsUpload() +} + +// AcceptsHTML Checks if request accepts html response +func (input *BeegoInput) AcceptsHTML() bool { + return (*context.BeegoInput)(input).AcceptsHTML() +} + +// AcceptsXML Checks if request accepts xml response +func (input *BeegoInput) AcceptsXML() bool { + return (*context.BeegoInput)(input).AcceptsXML() +} + +// AcceptsJSON Checks if request accepts json response +func (input *BeegoInput) AcceptsJSON() bool { + return (*context.BeegoInput)(input).AcceptsJSON() +} + +// AcceptsYAML Checks if request accepts json response +func (input *BeegoInput) AcceptsYAML() bool { + return (*context.BeegoInput)(input).AcceptsYAML() +} + +// IP returns request client ip. +// if in proxy, return first proxy id. +// if error, return RemoteAddr. +func (input *BeegoInput) IP() string { + return (*context.BeegoInput)(input).IP() +} + +// Proxy returns proxy client ips slice. +func (input *BeegoInput) Proxy() []string { + return (*context.BeegoInput)(input).Proxy() +} + +// Referer returns http referer header. +func (input *BeegoInput) Referer() string { + return (*context.BeegoInput)(input).Referer() +} + +// Refer returns http referer header. +func (input *BeegoInput) Refer() string { + return (*context.BeegoInput)(input).Refer() +} + +// SubDomains returns sub domain string. +// if aa.bb.domain.com, returns aa.bb . +func (input *BeegoInput) SubDomains() string { + return (*context.BeegoInput)(input).SubDomains() +} + +// Port returns request client port. +// when error or empty, return 80. +func (input *BeegoInput) Port() int { + return (*context.BeegoInput)(input).Port() +} + +// UserAgent returns request client user agent string. +func (input *BeegoInput) UserAgent() string { + return (*context.BeegoInput)(input).UserAgent() +} + +// ParamsLen return the length of the params +func (input *BeegoInput) ParamsLen() int { + return (*context.BeegoInput)(input).ParamsLen() +} + +// Param returns router param by a given key. +func (input *BeegoInput) Param(key string) string { + return (*context.BeegoInput)(input).Param(key) +} + +// Params returns the map[key]value. +func (input *BeegoInput) Params() map[string]string { + return (*context.BeegoInput)(input).Params() +} + +// SetParam will set the param with key and value +func (input *BeegoInput) SetParam(key, val string) { + (*context.BeegoInput)(input).SetParam(key, val) +} + +// ResetParams clears any of the input's Params +// This function is used to clear parameters so they may be reset between filter +// passes. +func (input *BeegoInput) ResetParams() { + (*context.BeegoInput)(input).ResetParams() +} + +// Query returns input data item string by a given string. +func (input *BeegoInput) Query(key string) string { + return (*context.BeegoInput)(input).Query(key) +} + +// Header returns request header item string by a given string. +// if non-existed, return empty string. +func (input *BeegoInput) Header(key string) string { + return (*context.BeegoInput)(input).Header(key) +} + +// Cookie returns request cookie item string by a given key. +// if non-existed, return empty string. +func (input *BeegoInput) Cookie(key string) string { + return (*context.BeegoInput)(input).Cookie(key) +} + +// Session returns current session item value by a given key. +// if non-existed, return nil. +func (input *BeegoInput) Session(key interface{}) interface{} { + return (*context.BeegoInput)(input).Session(key) +} + +// CopyBody returns the raw request body data as bytes. +func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { + return (*context.BeegoInput)(input).CopyBody(MaxMemory) +} + +// Data return the implicit data in the input +func (input *BeegoInput) Data() map[interface{}]interface{} { + return (*context.BeegoInput)(input).Data() +} + +// GetData returns the stored data in this context. +func (input *BeegoInput) GetData(key interface{}) interface{} { + return (*context.BeegoInput)(input).GetData(key) +} + +// SetData stores data with given key in this context. +// This data are only available in this context. +func (input *BeegoInput) SetData(key, val interface{}) { + (*context.BeegoInput)(input).SetData(key, val) +} + +// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type +func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { + return (*context.BeegoInput)(input).ParseFormOrMulitForm(maxMemory) +} + +// Bind data from request.Form[key] to dest +// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie +// var id int beegoInput.Bind(&id, "id") id ==123 +// var isok bool beegoInput.Bind(&isok, "isok") isok ==true +// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2 +// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2] +// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array] +// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"} +func (input *BeegoInput) Bind(dest interface{}, key string) error { + return (*context.BeegoInput)(input).Bind(dest, key) +} diff --git a/pkg/adapter/context/output.go b/pkg/adapter/context/output.go new file mode 100644 index 00000000..8e2a7f7d --- /dev/null +++ b/pkg/adapter/context/output.go @@ -0,0 +1,154 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "github.com/astaxie/beego/pkg/server/web/context" +) + +// BeegoOutput does work for sending response header. +type BeegoOutput context.BeegoOutput + +// NewOutput returns new BeegoOutput. +// it contains nothing now. +func NewOutput() *BeegoOutput { + return (*BeegoOutput)(context.NewOutput()) +} + +// Reset init BeegoOutput +func (output *BeegoOutput) Reset(ctx *Context) { + (*context.BeegoOutput)(output).Reset((*context.Context)(ctx)) +} + +// Header sets response header item string via given key. +func (output *BeegoOutput) Header(key, val string) { + (*context.BeegoOutput)(output).Header(key, val) +} + +// Body sets response body content. +// if EnableGzip, compress content string. +// it sends out response body directly. +func (output *BeegoOutput) Body(content []byte) error { + return (*context.BeegoOutput)(output).Body(content) +} + +// Cookie sets cookie value via given key. +// others are ordered as cookie's max age time, path,domain, secure and httponly. +func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) { + (*context.BeegoOutput)(output).Cookie(name, value, others) +} + +// JSON writes json to response body. +// if encoding is true, it converts utf-8 to \u0000 type. +func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error { + return (*context.BeegoOutput)(output).JSON(data, hasIndent, encoding) +} + +// YAML writes yaml to response body. +func (output *BeegoOutput) YAML(data interface{}) error { + return (*context.BeegoOutput)(output).YAML(data) +} + +// JSONP writes jsonp to response body. +func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { + return (*context.BeegoOutput)(output).JSONP(data, hasIndent) +} + +// XML writes xml string to response body. +func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { + return (*context.BeegoOutput)(output).XML(data, hasIndent) +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) { + (*context.BeegoOutput)(output).ServeFormatted(data, hasIndent, hasEncode...) +} + +// Download forces response for download file. +// it prepares the download response header automatically. +func (output *BeegoOutput) Download(file string, filename ...string) { + (*context.BeegoOutput)(output).Download(file, filename...) +} + +// ContentType sets the content type from ext string. +// MIME type is given in mime package. +func (output *BeegoOutput) ContentType(ext string) { + (*context.BeegoOutput)(output).ContentType(ext) +} + +// SetStatus sets response status code. +// It writes response header directly. +func (output *BeegoOutput) SetStatus(status int) { + (*context.BeegoOutput)(output).SetStatus(status) +} + +// IsCachable returns boolean of this request is cached. +// HTTP 304 means cached. +func (output *BeegoOutput) IsCachable() bool { + return (*context.BeegoOutput)(output).IsCachable() +} + +// IsEmpty returns boolean of this request is empty. +// HTTP 201,204 and 304 means empty. +func (output *BeegoOutput) IsEmpty() bool { + return (*context.BeegoOutput)(output).IsEmpty() +} + +// IsOk returns boolean of this request runs well. +// HTTP 200 means ok. +func (output *BeegoOutput) IsOk() bool { + return (*context.BeegoOutput)(output).IsOk() +} + +// IsSuccessful returns boolean of this request runs successfully. +// HTTP 2xx means ok. +func (output *BeegoOutput) IsSuccessful() bool { + return (*context.BeegoOutput)(output).IsSuccessful() +} + +// IsRedirect returns boolean of this request is redirection header. +// HTTP 301,302,307 means redirection. +func (output *BeegoOutput) IsRedirect() bool { + return (*context.BeegoOutput)(output).IsRedirect() +} + +// IsForbidden returns boolean of this request is forbidden. +// HTTP 403 means forbidden. +func (output *BeegoOutput) IsForbidden() bool { + return (*context.BeegoOutput)(output).IsForbidden() +} + +// IsNotFound returns boolean of this request is not found. +// HTTP 404 means not found. +func (output *BeegoOutput) IsNotFound() bool { + return (*context.BeegoOutput)(output).IsNotFound() +} + +// IsClientError returns boolean of this request client sends error data. +// HTTP 4xx means client error. +func (output *BeegoOutput) IsClientError() bool { + return (*context.BeegoOutput)(output).IsClientError() +} + +// IsServerError returns boolean of this server handler errors. +// HTTP 5xx means server internal error. +func (output *BeegoOutput) IsServerError() bool { + return (*context.BeegoOutput)(output).IsServerError() +} + +// Session sets session item value with given key. +func (output *BeegoOutput) Session(name interface{}, value interface{}) { + (*context.BeegoOutput)(output).Session(name, value) +} diff --git a/pkg/adapter/context/renderer.go b/pkg/adapter/context/renderer.go new file mode 100644 index 00000000..7e352007 --- /dev/null +++ b/pkg/adapter/context/renderer.go @@ -0,0 +1,9 @@ +package context + +import ( + "github.com/astaxie/beego/pkg/server/web/context" +) + +// Renderer defines an http response renderer +type Renderer context.Renderer + diff --git a/pkg/adapter/context/response.go b/pkg/adapter/context/response.go new file mode 100644 index 00000000..24e196a4 --- /dev/null +++ b/pkg/adapter/context/response.go @@ -0,0 +1,26 @@ +package context + +import ( + "net/http" + "strconv" +) + +const ( + // BadRequest indicates http error 400 + BadRequest StatusCode = http.StatusBadRequest + + // NotFound indicates http error 404 + NotFound StatusCode = http.StatusNotFound +) + +// StatusCode sets the http response status code +type StatusCode int + +func (s StatusCode) Error() string { + return strconv.Itoa(int(s)) +} + +// Render sets the http status code +func (s StatusCode) Render(ctx *Context) { + ctx.Output.SetStatus(int(s)) +} From 8fc4f8847c4f9d887605ab4c56461a2feb0549de Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 2 Sep 2020 20:43:35 +0800 Subject: [PATCH 18/35] adapt grace and metric --- pkg/adapter/context/renderer.go | 1 - pkg/adapter/grace/grace.go | 96 ++++++++++++++++++++++++++ pkg/adapter/grace/server.go | 48 +++++++++++++ pkg/adapter/metric/prometheus.go | 99 +++++++++++++++++++++++++++ pkg/adapter/metric/prometheus_test.go | 42 ++++++++++++ 5 files changed, 285 insertions(+), 1 deletion(-) create mode 100644 pkg/adapter/grace/grace.go create mode 100644 pkg/adapter/grace/server.go create mode 100644 pkg/adapter/metric/prometheus.go create mode 100644 pkg/adapter/metric/prometheus_test.go diff --git a/pkg/adapter/context/renderer.go b/pkg/adapter/context/renderer.go index 7e352007..763fb9c4 100644 --- a/pkg/adapter/context/renderer.go +++ b/pkg/adapter/context/renderer.go @@ -6,4 +6,3 @@ import ( // Renderer defines an http response renderer type Renderer context.Renderer - diff --git a/pkg/adapter/grace/grace.go b/pkg/adapter/grace/grace.go new file mode 100644 index 00000000..67cd4a1e --- /dev/null +++ b/pkg/adapter/grace/grace.go @@ -0,0 +1,96 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package grace use to hot reload +// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ +// +// Usage: +// +// import( +// "log" +// "net/http" +// "os" +// +// "github.com/astaxie/beego/grace" +// ) +// +// func handler(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("WORLD!")) +// } +// +// func main() { +// mux := http.NewServeMux() +// mux.HandleFunc("/hello", handler) +// +// err := grace.ListenAndServe("localhost:8080", mux) +// if err != nil { +// log.Println(err) +// } +// log.Println("Server on 8080 stopped") +// os.Exit(0) +// } +package grace + +import ( + "net/http" + "time" + + "github.com/astaxie/beego/pkg/server/web/grace" +) + +const ( + // PreSignal is the position to add filter before signal + PreSignal = iota + // PostSignal is the position to add filter after signal + PostSignal + // StateInit represent the application inited + StateInit + // StateRunning represent the application is running + StateRunning + // StateShuttingDown represent the application is shutting down + StateShuttingDown + // StateTerminate represent the application is killed + StateTerminate +) + +var ( + + + // DefaultReadTimeOut is the HTTP read timeout + DefaultReadTimeOut time.Duration + // DefaultWriteTimeOut is the HTTP Write timeout + DefaultWriteTimeOut time.Duration + // DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit + DefaultMaxHeaderBytes int + // DefaultTimeout is the shutdown server's timeout. default is 60s + DefaultTimeout = grace.DefaultTimeout + +) + +// NewServer returns a new graceServer. +func NewServer(addr string, handler http.Handler) (srv *Server) { + return (*Server)(grace.NewServer(addr, handler)) +} + +// ListenAndServe refer http.ListenAndServe +func ListenAndServe(addr string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServe() +} + +// ListenAndServeTLS refer http.ListenAndServeTLS +func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServeTLS(certFile, keyFile) +} diff --git a/pkg/adapter/grace/server.go b/pkg/adapter/grace/server.go new file mode 100644 index 00000000..31c13f18 --- /dev/null +++ b/pkg/adapter/grace/server.go @@ -0,0 +1,48 @@ +package grace + +import ( + "os" + + "github.com/astaxie/beego/pkg/server/web/grace" +) + +// Server embedded http.Server +type Server grace.Server + +// Serve accepts incoming connections on the Listener l, +// creating a new service goroutine for each. +// The service goroutines read requests and then call srv.Handler to reply to them. +func (srv *Server) Serve() (err error) { + return (*grace.Server)(srv).Serve() +} + +// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve +// to handle requests on incoming connections. If srv.Addr is blank, ":http" is +// used. +func (srv *Server) ListenAndServe() (err error) { + return (*grace.Server)(srv).ListenAndServe() +} + +// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming TLS connections. +// +// Filenames containing a certificate and matching private key for the server must +// be provided. If the certificate is signed by a certificate authority, the +// certFile should be the concatenation of the server's certificate followed by the +// CA's certificate. +// +// If srv.Addr is blank, ":https" is used. +func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { + return (*grace.Server)(srv).ListenAndServeTLS(certFile, keyFile) +} + +// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming mutual TLS connections. +func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) error { + return (*grace.Server)(srv).ListenAndServeMutualTLS(certFile, keyFile, trustFile) +} + +// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. +func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) error { + return (*grace.Server)(srv).RegisterSignalHook(ppFlag, sig, f) +} diff --git a/pkg/adapter/metric/prometheus.go b/pkg/adapter/metric/prometheus.go new file mode 100644 index 00000000..1d3488c6 --- /dev/null +++ b/pkg/adapter/metric/prometheus.go @@ -0,0 +1,99 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "reflect" + "strconv" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/server/web" +) + +func PrometheusMiddleWare(next http.Handler) http.Handler { + summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "http_request", + ConstLabels: map[string]string{ + "server": web.BConfig.ServerName, + "env": web.BConfig.RunMode, + "appname": web.BConfig.AppName, + }, + Help: "The statics info for http request", + }, []string{"pattern", "method", "status", "duration"}) + + prometheus.MustRegister(summaryVec) + + registerBuildInfo() + + return http.HandlerFunc(func(writer http.ResponseWriter, q *http.Request) { + start := time.Now() + next.ServeHTTP(writer, q) + end := time.Now() + go report(end.Sub(start), writer, q, summaryVec) + }) +} + +func registerBuildInfo() { + buildInfo := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "beego", + Subsystem: "build_info", + Help: "The building information", + ConstLabels: map[string]string{ + "appname": web.BConfig.AppName, + "build_version": web.BuildVersion, + "build_revision": web.BuildGitRevision, + "build_status": web.BuildStatus, + "build_tag": web.BuildTag, + "build_time": strings.Replace(web.BuildTime, "--", " ", 1), + "go_version": web.GoVersion, + "git_branch": web.GitBranch, + "start_time": time.Now().Format("2006-01-02 15:04:05"), + }, + }, []string{}) + + prometheus.MustRegister(buildInfo) + buildInfo.WithLabelValues().Set(1) +} + +func report(dur time.Duration, writer http.ResponseWriter, q *http.Request, vec *prometheus.SummaryVec) { + ctrl := web.BeeApp.Handlers + ctx := ctrl.GetContext() + ctx.Reset(writer, q) + defer ctrl.GiveBackContext(ctx) + + // We cannot read the status code from q.Response.StatusCode + // since the http server does not set q.Response. So q.Response is nil + // Thus, we use reflection to read the status from writer whose concrete type is http.response + responseVal := reflect.ValueOf(writer).Elem() + field := responseVal.FieldByName("status") + status := -1 + if field.IsValid() && field.Kind() == reflect.Int { + status = int(field.Int()) + } + ptn := "UNKNOWN" + if rt, found := ctrl.FindRouter(ctx); found { + ptn = rt.GetPattern() + } else { + logs.Warn("we can not find the router info for this request, so request will be recorded as UNKNOWN: " + q.URL.String()) + } + ms := dur / time.Millisecond + vec.WithLabelValues(ptn, q.Method, strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms)) +} diff --git a/pkg/adapter/metric/prometheus_test.go b/pkg/adapter/metric/prometheus_test.go new file mode 100644 index 00000000..d82a6dec --- /dev/null +++ b/pkg/adapter/metric/prometheus_test.go @@ -0,0 +1,42 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "net/url" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego/context" +) + +func TestPrometheusMiddleWare(t *testing.T) { + middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + writer := &context.Response{} + request := &http.Request{ + URL: &url.URL{ + Host: "localhost", + RawPath: "/a/b/c", + }, + Method: "POST", + } + vec := prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{"pattern", "method", "status", "duration"}) + + report(time.Second, writer, request, vec) + middleware.ServeHTTP(writer, request) +} From bdd8df675135f0c3b716130cbdf363c0ddf79567 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 2 Sep 2020 21:01:54 +0800 Subject: [PATCH 19/35] adapt migration --- pkg/adapter/grace/grace.go | 2 - pkg/adapter/migration/ddl.go | 198 +++++++++++++++++++++++++++++ pkg/adapter/migration/doc.go | 32 +++++ pkg/adapter/migration/migration.go | 111 ++++++++++++++++ pkg/client/orm/migration/ddl.go | 52 ++++---- 5 files changed, 367 insertions(+), 28 deletions(-) create mode 100644 pkg/adapter/migration/ddl.go create mode 100644 pkg/adapter/migration/doc.go create mode 100644 pkg/adapter/migration/migration.go diff --git a/pkg/adapter/grace/grace.go b/pkg/adapter/grace/grace.go index 67cd4a1e..3775e395 100644 --- a/pkg/adapter/grace/grace.go +++ b/pkg/adapter/grace/grace.go @@ -66,7 +66,6 @@ const ( var ( - // DefaultReadTimeOut is the HTTP read timeout DefaultReadTimeOut time.Duration // DefaultWriteTimeOut is the HTTP Write timeout @@ -75,7 +74,6 @@ var ( DefaultMaxHeaderBytes int // DefaultTimeout is the shutdown server's timeout. default is 60s DefaultTimeout = grace.DefaultTimeout - ) // NewServer returns a new graceServer. diff --git a/pkg/adapter/migration/ddl.go b/pkg/adapter/migration/ddl.go new file mode 100644 index 00000000..97e45dec --- /dev/null +++ b/pkg/adapter/migration/ddl.go @@ -0,0 +1,198 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package migration + +import ( + "github.com/astaxie/beego/pkg/client/orm/migration" +) + +// Index struct defines the structure of Index Columns +type Index migration.Index + +// Unique struct defines a single unique key combination +type Unique migration.Unique + +// Column struct defines a single column of a table +type Column migration.Column + +// Foreign struct defines a single foreign relationship +type Foreign migration.Foreign + +// RenameColumn struct allows renaming of columns +type RenameColumn migration.RenameColumn + +// CreateTable creates the table on system +func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) { + (*migration.Migration)(m).CreateTable(tablename, engine, charset, p...) +} + +// AlterTable set the ModifyType to alter +func (m *Migration) AlterTable(tablename string) { + (*migration.Migration)(m).AlterTable(tablename) +} + +// NewCol creates a new standard column and attaches it to m struct +func (m *Migration) NewCol(name string) *Column { + return (*Column)((*migration.Migration)(m).NewCol(name)) +} + +// PriCol creates a new primary column and attaches it to m struct +func (m *Migration) PriCol(name string) *Column { + return (*Column)((*migration.Migration)(m).PriCol(name)) +} + +// UniCol creates / appends columns to specified unique key and attaches it to m struct +func (m *Migration) UniCol(uni, name string) *Column { + return (*Column)((*migration.Migration)(m).UniCol(uni, name)) +} + +// ForeignCol creates a new foreign column and returns the instance of column +func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { + return (*Foreign)((*migration.Migration)(m).ForeignCol(colname, foreigncol, foreigntable)) +} + +// SetOnDelete sets the on delete of foreign +func (foreign *Foreign) SetOnDelete(del string) *Foreign { + (*migration.Foreign)(foreign).SetOnDelete(del) + return foreign +} + +// SetOnUpdate sets the on update of foreign +func (foreign *Foreign) SetOnUpdate(update string) *Foreign { + (*migration.Foreign)(foreign).SetOnUpdate(update) + return foreign +} + +// Remove marks the columns to be removed. +// it allows reverse m to create the column. +func (c *Column) Remove() { + (*migration.Column)(c).Remove() +} + +// SetAuto enables auto_increment of column (can be used once) +func (c *Column) SetAuto(inc bool) *Column { + (*migration.Column)(c).SetAuto(inc) + return c +} + +// SetNullable sets the column to be null +func (c *Column) SetNullable(null bool) *Column { + (*migration.Column)(c).SetNullable(null) + return c +} + +// SetDefault sets the default value, prepend with "DEFAULT " +func (c *Column) SetDefault(def string) *Column { + (*migration.Column)(c).SetDefault(def) + return c +} + +// SetUnsigned sets the column to be unsigned int +func (c *Column) SetUnsigned(unsign bool) *Column { + (*migration.Column)(c).SetUnsigned(unsign) + return c +} + +// SetDataType sets the dataType of the column +func (c *Column) SetDataType(dataType string) *Column { + (*migration.Column)(c).SetDataType(dataType) + return c +} + +// SetOldNullable allows reverting to previous nullable on reverse ms +func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { + (*migration.RenameColumn)(c).SetOldNullable(null) + return c +} + +// SetOldDefault allows reverting to previous default on reverse ms +func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { + (*migration.RenameColumn)(c).SetOldDefault(def) + return c +} + +// SetOldUnsigned allows reverting to previous unsgined on reverse ms +func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { + (*migration.RenameColumn)(c).SetOldUnsigned(unsign) + return c +} + +// SetOldDataType allows reverting to previous datatype on reverse ms +func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { + (*migration.RenameColumn)(c).SetOldDataType(dataType) + return c +} + +// SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) +func (c *Column) SetPrimary(m *Migration) *Column { + (*migration.Column)(c).SetPrimary((*migration.Migration)(m)) + return c +} + +// AddColumnsToUnique adds the columns to Unique Struct +func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { + cls := toNewColumnsArray(columns) + (*migration.Unique)(unique).AddColumnsToUnique(cls...) + return unique +} + +// AddColumns adds columns to m struct +func (m *Migration) AddColumns(columns ...*Column) *Migration { + cls := toNewColumnsArray(columns) + (*migration.Migration)(m).AddColumns(cls...) + return m +} + +func toNewColumnsArray(columns []*Column) []*migration.Column { + cls := make([]*migration.Column, 0, len(columns)) + for _, c := range columns { + cls = append(cls, (*migration.Column)(c)) + } + return cls +} + +// AddPrimary adds the column to primary in m struct +func (m *Migration) AddPrimary(primary *Column) *Migration { + (*migration.Migration)(m).AddPrimary((*migration.Column)(primary)) + return m +} + +// AddUnique adds the column to unique in m struct +func (m *Migration) AddUnique(unique *Unique) *Migration { + (*migration.Migration)(m).AddUnique((*migration.Unique)(unique)) + return m +} + +// AddForeign adds the column to foreign in m struct +func (m *Migration) AddForeign(foreign *Foreign) *Migration { + (*migration.Migration)(m).AddForeign((*migration.Foreign)(foreign)) + return m +} + +// AddIndex adds the column to index in m struct +func (m *Migration) AddIndex(index *Index) *Migration { + (*migration.Migration)(m).AddIndex((*migration.Index)(index)) + return m +} + +// RenameColumn allows renaming of columns +func (m *Migration) RenameColumn(from, to string) *RenameColumn { + return (*RenameColumn)((*migration.Migration)(m).RenameColumn(from, to)) +} + +// GetSQL returns the generated sql depending on ModifyType +func (m *Migration) GetSQL() (sql string) { + return (*migration.Migration)(m).GetSQL() +} diff --git a/pkg/adapter/migration/doc.go b/pkg/adapter/migration/doc.go new file mode 100644 index 00000000..0c6564d4 --- /dev/null +++ b/pkg/adapter/migration/doc.go @@ -0,0 +1,32 @@ +// Package migration enables you to generate migrations back and forth. It generates both migrations. +// +// //Creates a table +// m.CreateTable("tablename","InnoDB","utf8"); +// +// //Alter a table +// m.AlterTable("tablename") +// +// Standard Column Methods +// * SetDataType +// * SetNullable +// * SetDefault +// * SetUnsigned (use only on integer types unless produces error) +// +// //Sets a primary column, multiple calls allowed, standard column methods available +// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true) +// +// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index +// m.UniCol("index","column") +// +// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove +// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false) +// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false) +// +// //Rename Columns , only use with Alter table, doesn't works with Create, prefix standard column methods with "Old" to +// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)") +// m.RenameColumn("from","to")... +// +// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately. +// //Supports standard column methods, automatic reverse. +// m.ForeignCol("local_col","foreign_col","foreign_table") +package migration diff --git a/pkg/adapter/migration/migration.go b/pkg/adapter/migration/migration.go new file mode 100644 index 00000000..4ee22e5a --- /dev/null +++ b/pkg/adapter/migration/migration.go @@ -0,0 +1,111 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package migration is used for migration +// +// The table structure is as follow: +// +// CREATE TABLE `migrations` ( +// `id_migration` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'surrogate key', +// `name` varchar(255) DEFAULT NULL COMMENT 'migration name, unique', +// `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'date migrated or rolled back', +// `statements` longtext COMMENT 'SQL statements for this migration', +// `rollback_statements` longtext, +// `status` enum('update','rollback') DEFAULT NULL COMMENT 'update indicates it is a normal migration while rollback means this migration is rolled back', +// PRIMARY KEY (`id_migration`) +// ) ENGINE=InnoDB DEFAULT CHARSET=utf8; +package migration + +import ( + "github.com/astaxie/beego/pkg/client/orm/migration" +) + +// const the data format for the bee generate migration datatype +const ( + DateFormat = "20060102_150405" + DBDateFormat = "2006-01-02 15:04:05" +) + +// Migrationer is an interface for all Migration struct +type Migrationer interface { + Up() + Down() + Reset() + Exec(name, status string) error + GetCreated() int64 +} + +// Migration defines the migrations by either SQL or DDL +type Migration migration.Migration + +// Up implement in the Inheritance struct for upgrade +func (m *Migration) Up() { + (*migration.Migration)(m).Up() +} + +// Down implement in the Inheritance struct for down +func (m *Migration) Down() { + (*migration.Migration)(m).Down() +} + +// Migrate adds the SQL to the execution list +func (m *Migration) Migrate(migrationType string) { + (*migration.Migration)(m).Migrate(migrationType) +} + +// SQL add sql want to execute +func (m *Migration) SQL(sql string) { + (*migration.Migration)(m).SQL(sql) +} + +// Reset the sqls +func (m *Migration) Reset() { + (*migration.Migration)(m).Reset() +} + +// Exec execute the sql already add in the sql +func (m *Migration) Exec(name, status string) error { + return (*migration.Migration)(m).Exec(name, status) +} + +// GetCreated get the unixtime from the Created +func (m *Migration) GetCreated() int64 { + return (*migration.Migration)(m).GetCreated() +} + +// Register register the Migration in the map +func Register(name string, m Migrationer) error { + return migration.Register(name, m) +} + +// Upgrade upgrade the migration from lasttime +func Upgrade(lasttime int64) error { + return migration.Upgrade(lasttime) +} + +// Rollback rollback the migration by the name +func Rollback(name string) error { + return migration.Rollback(name) +} + +// Reset reset all migration +// run all migration's down function +func Reset() error { + return migration.Reset() +} + +// Refresh first Reset, then Upgrade +func Refresh() error { + return migration.Refresh() +} diff --git a/pkg/client/orm/migration/ddl.go b/pkg/client/orm/migration/ddl.go index c21352a8..e8b13212 100644 --- a/pkg/client/orm/migration/ddl.go +++ b/pkg/client/orm/migration/ddl.go @@ -31,7 +31,7 @@ type Unique struct { Columns []*Column } -//Column struct defines a single column of a table +// Column struct defines a single column of a table type Column struct { Name string Inc string @@ -84,7 +84,7 @@ func (m *Migration) NewCol(name string) *Column { return col } -//PriCol creates a new primary column and attaches it to m struct +// PriCol creates a new primary column and attaches it to m struct func (m *Migration) PriCol(name string) *Column { col := &Column{Name: name} m.AddColumns(col) @@ -92,7 +92,7 @@ func (m *Migration) PriCol(name string) *Column { return col } -//UniCol creates / appends columns to specified unique key and attaches it to m struct +// UniCol creates / appends columns to specified unique key and attaches it to m struct func (m *Migration) UniCol(uni, name string) *Column { col := &Column{Name: name} m.AddColumns(col) @@ -114,7 +114,7 @@ func (m *Migration) UniCol(uni, name string) *Column { return col } -//ForeignCol creates a new foreign column and returns the instance of column +// ForeignCol creates a new foreign column and returns the instance of column func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable} @@ -123,25 +123,25 @@ func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreig return foreign } -//SetOnDelete sets the on delete of foreign +// SetOnDelete sets the on delete of foreign func (foreign *Foreign) SetOnDelete(del string) *Foreign { foreign.OnDelete = "ON DELETE" + del return foreign } -//SetOnUpdate sets the on update of foreign +// SetOnUpdate sets the on update of foreign func (foreign *Foreign) SetOnUpdate(update string) *Foreign { foreign.OnUpdate = "ON UPDATE" + update return foreign } -//Remove marks the columns to be removed. -//it allows reverse m to create the column. +// Remove marks the columns to be removed. +// it allows reverse m to create the column. func (c *Column) Remove() { c.remove = true } -//SetAuto enables auto_increment of column (can be used once) +// SetAuto enables auto_increment of column (can be used once) func (c *Column) SetAuto(inc bool) *Column { if inc { c.Inc = "auto_increment" @@ -149,7 +149,7 @@ func (c *Column) SetAuto(inc bool) *Column { return c } -//SetNullable sets the column to be null +// SetNullable sets the column to be null func (c *Column) SetNullable(null bool) *Column { if null { c.Null = "" @@ -160,13 +160,13 @@ func (c *Column) SetNullable(null bool) *Column { return c } -//SetDefault sets the default value, prepend with "DEFAULT " +// SetDefault sets the default value, prepend with "DEFAULT " func (c *Column) SetDefault(def string) *Column { c.Default = "DEFAULT " + def return c } -//SetUnsigned sets the column to be unsigned int +// SetUnsigned sets the column to be unsigned int func (c *Column) SetUnsigned(unsign bool) *Column { if unsign { c.Unsign = "UNSIGNED" @@ -174,13 +174,13 @@ func (c *Column) SetUnsigned(unsign bool) *Column { return c } -//SetDataType sets the dataType of the column +// SetDataType sets the dataType of the column func (c *Column) SetDataType(dataType string) *Column { c.DataType = dataType return c } -//SetOldNullable allows reverting to previous nullable on reverse ms +// SetOldNullable allows reverting to previous nullable on reverse ms func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { if null { c.OldNull = "" @@ -191,13 +191,13 @@ func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { return c } -//SetOldDefault allows reverting to previous default on reverse ms +// SetOldDefault allows reverting to previous default on reverse ms func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { c.OldDefault = def return c } -//SetOldUnsigned allows reverting to previous unsgined on reverse ms +// SetOldUnsigned allows reverting to previous unsgined on reverse ms func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { if unsign { c.OldUnsign = "UNSIGNED" @@ -205,19 +205,19 @@ func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { return c } -//SetOldDataType allows reverting to previous datatype on reverse ms +// SetOldDataType allows reverting to previous datatype on reverse ms func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { c.OldDataType = dataType return c } -//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) +// SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) func (c *Column) SetPrimary(m *Migration) *Column { m.Primary = append(m.Primary, c) return c } -//AddColumnsToUnique adds the columns to Unique Struct +// AddColumnsToUnique adds the columns to Unique Struct func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { unique.Columns = append(unique.Columns, columns...) @@ -225,7 +225,7 @@ func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { return unique } -//AddColumns adds columns to m struct +// AddColumns adds columns to m struct func (m *Migration) AddColumns(columns ...*Column) *Migration { m.Columns = append(m.Columns, columns...) @@ -233,38 +233,38 @@ func (m *Migration) AddColumns(columns ...*Column) *Migration { return m } -//AddPrimary adds the column to primary in m struct +// AddPrimary adds the column to primary in m struct func (m *Migration) AddPrimary(primary *Column) *Migration { m.Primary = append(m.Primary, primary) return m } -//AddUnique adds the column to unique in m struct +// AddUnique adds the column to unique in m struct func (m *Migration) AddUnique(unique *Unique) *Migration { m.Uniques = append(m.Uniques, unique) return m } -//AddForeign adds the column to foreign in m struct +// AddForeign adds the column to foreign in m struct func (m *Migration) AddForeign(foreign *Foreign) *Migration { m.Foreigns = append(m.Foreigns, foreign) return m } -//AddIndex adds the column to index in m struct +// AddIndex adds the column to index in m struct func (m *Migration) AddIndex(index *Index) *Migration { m.Indexes = append(m.Indexes, index) return m } -//RenameColumn allows renaming of columns +// RenameColumn allows renaming of columns func (m *Migration) RenameColumn(from, to string) *RenameColumn { rename := &RenameColumn{OldName: from, NewName: to} m.Renames = append(m.Renames, rename) return rename } -//GetSQL returns the generated sql depending on ModifyType +// GetSQL returns the generated sql depending on ModifyType func (m *Migration) GetSQL() (sql string) { sql = "" switch m.ModifyType { From cbd51616f17361706060c8e7d1dab4265e519d8c Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 2 Sep 2020 23:23:48 +0800 Subject: [PATCH 20/35] adapter: validation module --- pkg/adapter/validation/util.go | 62 +++ pkg/adapter/validation/validation.go | 274 ++++++++++ pkg/adapter/validation/validation_test.go | 609 ++++++++++++++++++++++ pkg/adapter/validation/validators.go | 512 ++++++++++++++++++ 4 files changed, 1457 insertions(+) create mode 100644 pkg/adapter/validation/util.go create mode 100644 pkg/adapter/validation/validation.go create mode 100644 pkg/adapter/validation/validation_test.go create mode 100644 pkg/adapter/validation/validators.go diff --git a/pkg/adapter/validation/util.go b/pkg/adapter/validation/util.go new file mode 100644 index 00000000..729712e0 --- /dev/null +++ b/pkg/adapter/validation/util.go @@ -0,0 +1,62 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "reflect" + + "github.com/astaxie/beego/pkg/infrastructure/validation" +) + +const ( + // ValidTag struct tag + ValidTag = validation.ValidTag + + LabelTag = validation.LabelTag +) + +var ( + ErrInt64On32 = validation.ErrInt64On32 +) + +// CustomFunc is for custom validate function +type CustomFunc func(v *Validation, obj interface{}, key string) + +// AddCustomFunc Add a custom function to validation +// The name can not be: +// Clear +// HasErrors +// ErrorMap +// Error +// Check +// Valid +// NoMatch +// If the name is same with exists function, it will replace the origin valid function +func AddCustomFunc(name string, f CustomFunc) error { + return validation.AddCustomFunc(name, func(v *validation.Validation, obj interface{}, key string) { + f((*Validation)(v), obj, key) + }) +} + +// ValidFunc Valid function type +type ValidFunc validation.ValidFunc + +// Funcs Validate function map +type Funcs validation.Funcs + +// Call validate values with named type string +func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) { + return (validation.Funcs(f)).Call(name, params...) +} diff --git a/pkg/adapter/validation/validation.go b/pkg/adapter/validation/validation.go new file mode 100644 index 00000000..1cdb8dda --- /dev/null +++ b/pkg/adapter/validation/validation.go @@ -0,0 +1,274 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package validation for validations +// +// import ( +// "github.com/astaxie/beego/validation" +// "log" +// ) +// +// type User struct { +// Name string +// Age int +// } +// +// func main() { +// u := User{"man", 40} +// valid := validation.Validation{} +// valid.Required(u.Name, "name") +// valid.MaxSize(u.Name, 15, "nameMax") +// valid.Range(u.Age, 0, 140, "age") +// if valid.HasErrors() { +// // validation does not pass +// // print invalid message +// for _, err := range valid.Errors { +// log.Println(err.Key, err.Message) +// } +// } +// // or use like this +// if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { +// log.Println(v.Error.Key, v.Error.Message) +// } +// } +// +// more info: http://beego.me/docs/mvc/controller/validation.md +package validation + +import ( + "fmt" + "regexp" + + "github.com/astaxie/beego/pkg/infrastructure/validation" +) + +// ValidFormer valid interface +type ValidFormer interface { + Valid(*Validation) +} + +// Error show the error +type Error validation.Error + +// String Returns the Message. +func (e *Error) String() string { + if e == nil { + return "" + } + return e.Message +} + +// Implement Error interface. +// Return e.String() +func (e *Error) Error() string { return e.String() } + +// Result is returned from every validation method. +// It provides an indication of success, and a pointer to the Error (if any). +type Result validation.Result + +// Key Get Result by given key string. +func (r *Result) Key(key string) *Result { + if r.Error != nil { + r.Error.Key = key + } + return r +} + +// Message Set Result message by string or format string with args +func (r *Result) Message(message string, args ...interface{}) *Result { + if r.Error != nil { + if len(args) == 0 { + r.Error.Message = message + } else { + r.Error.Message = fmt.Sprintf(message, args...) + } + } + return r +} + +// A Validation context manages data validation and error messages. +type Validation validation.Validation + +// Clear Clean all ValidationError. +func (v *Validation) Clear() { + (*validation.Validation)(v).Clear() +} + +// HasErrors Has ValidationError nor not. +func (v *Validation) HasErrors() bool { + return (*validation.Validation)(v).HasErrors() +} + +// ErrorMap Return the errors mapped by key. +// If there are multiple validation errors associated with a single key, the +// first one "wins". (Typically the first validation will be the more basic). +func (v *Validation) ErrorMap() map[string][]*Error { + newErrors := (*validation.Validation)(v).ErrorMap() + res := make(map[string][]*Error, len(newErrors)) + for n, es := range newErrors { + errs := make([]*Error, 0, len(es)) + + for _, e := range es { + errs = append(errs, (*Error)(e)) + } + + res[n] = errs + } + return res +} + +// Error Add an error to the validation context. +func (v *Validation) Error(message string, args ...interface{}) *Result { + return (*Result)((*validation.Validation)(v).Error(message, args...)) +} + +// Required Test that the argument is non-nil and non-empty (if string or list) +func (v *Validation) Required(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Required(obj, key)) +} + +// Min Test that the obj is greater than min if obj's type is int +func (v *Validation) Min(obj interface{}, min int, key string) *Result { + return (*Result)((*validation.Validation)(v).Min(obj, min, key)) +} + +// Max Test that the obj is less than max if obj's type is int +func (v *Validation) Max(obj interface{}, max int, key string) *Result { + return (*Result)((*validation.Validation)(v).Max(obj, max, key)) +} + +// Range Test that the obj is between mni and max if obj's type is int +func (v *Validation) Range(obj interface{}, min, max int, key string) *Result { + return (*Result)((*validation.Validation)(v).Range(obj, min, max, key)) +} + +// MinSize Test that the obj is longer than min size if type is string or slice +func (v *Validation) MinSize(obj interface{}, min int, key string) *Result { + return (*Result)((*validation.Validation)(v).MinSize(obj, min, key)) +} + +// MaxSize Test that the obj is shorter than max size if type is string or slice +func (v *Validation) MaxSize(obj interface{}, max int, key string) *Result { + return (*Result)((*validation.Validation)(v).MaxSize(obj, max, key)) +} + +// Length Test that the obj is same length to n if type is string or slice +func (v *Validation) Length(obj interface{}, n int, key string) *Result { + return (*Result)((*validation.Validation)(v).Length(obj, n, key)) +} + +// Alpha Test that the obj is [a-zA-Z] if type is string +func (v *Validation) Alpha(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Alpha(obj, key)) +} + +// Numeric Test that the obj is [0-9] if type is string +func (v *Validation) Numeric(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Numeric(obj, key)) +} + +// AlphaNumeric Test that the obj is [0-9a-zA-Z] if type is string +func (v *Validation) AlphaNumeric(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).AlphaNumeric(obj, key)) +} + +// Match Test that the obj matches regexp if type is string +func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *Result { + return (*Result)((*validation.Validation)(v).Match(obj, regex, key)) +} + +// NoMatch Test that the obj doesn't match regexp if type is string +func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *Result { + return (*Result)((*validation.Validation)(v).NoMatch(obj, regex, key)) +} + +// AlphaDash Test that the obj is [0-9a-zA-Z_-] if type is string +func (v *Validation) AlphaDash(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).AlphaDash(obj, key)) +} + +// Email Test that the obj is email address if type is string +func (v *Validation) Email(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Email(obj, key)) +} + +// IP Test that the obj is IP address if type is string +func (v *Validation) IP(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).IP(obj, key)) +} + +// Base64 Test that the obj is base64 encoded if type is string +func (v *Validation) Base64(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Base64(obj, key)) +} + +// Mobile Test that the obj is chinese mobile number if type is string +func (v *Validation) Mobile(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Mobile(obj, key)) +} + +// Tel Test that the obj is chinese telephone number if type is string +func (v *Validation) Tel(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Tel(obj, key)) +} + +// Phone Test that the obj is chinese mobile or telephone number if type is string +func (v *Validation) Phone(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Phone(obj, key)) +} + +// ZipCode Test that the obj is chinese zip code if type is string +func (v *Validation) ZipCode(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).ZipCode(obj, key)) +} + +// key must like aa.bb.cc or aa.bb. +// AddError adds independent error message for the provided key +func (v *Validation) AddError(key, message string) { + (*validation.Validation)(v).AddError(key, message) +} + +// SetError Set error message for one field in ValidationError +func (v *Validation) SetError(fieldName string, errMsg string) *Error { + return (*Error)((*validation.Validation)(v).SetError(fieldName, errMsg)) +} + +// Check Apply a group of validators to a field, in order, and return the +// ValidationResult from the first one that fails, or the last one that +// succeeds. +func (v *Validation) Check(obj interface{}, checks ...Validator) *Result { + vldts := make([]validation.Validator, 0, len(checks)) + for _, v := range checks { + vldts = append(vldts, validation.Validator(v)) + } + return (*Result)((*validation.Validation)(v).Check(obj, vldts...)) +} + +// Valid Validate a struct. +// the obj parameter must be a struct or a struct pointer +func (v *Validation) Valid(obj interface{}) (b bool, err error) { + return (*validation.Validation)(v).Valid(obj) +} + +// RecursiveValid Recursively validate a struct. +// Step1: Validate by v.Valid +// Step2: If pass on step1, then reflect obj's fields +// Step3: Do the Recursively validation to all struct or struct pointer fields +func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { + return (*validation.Validation)(v).RecursiveValid(objc) +} + +func (v *Validation) CanSkipAlso(skipFunc string) { + (*validation.Validation)(v).CanSkipAlso(skipFunc) +} diff --git a/pkg/adapter/validation/validation_test.go b/pkg/adapter/validation/validation_test.go new file mode 100644 index 00000000..b4b5b1b6 --- /dev/null +++ b/pkg/adapter/validation/validation_test.go @@ -0,0 +1,609 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "regexp" + "testing" + "time" +) + +func TestRequired(t *testing.T) { + valid := Validation{} + + if valid.Required(nil, "nil").Ok { + t.Error("nil object should be false") + } + if !valid.Required(true, "bool").Ok { + t.Error("Bool value should always return true") + } + if !valid.Required(false, "bool").Ok { + t.Error("Bool value should always return true") + } + if valid.Required("", "string").Ok { + t.Error("\"'\" string should be false") + } + if valid.Required(" ", "string").Ok { + t.Error("\" \" string should be false") // For #2361 + } + if valid.Required("\n", "string").Ok { + t.Error("new line string should be false") // For #2361 + } + if !valid.Required("astaxie", "string").Ok { + t.Error("string should be true") + } + if valid.Required(0, "zero").Ok { + t.Error("Integer should not be equal 0") + } + if !valid.Required(1, "int").Ok { + t.Error("Integer except 0 should be true") + } + if !valid.Required(time.Now(), "time").Ok { + t.Error("time should be true") + } + if valid.Required([]string{}, "emptySlice").Ok { + t.Error("empty slice should be false") + } + if !valid.Required([]interface{}{"ok"}, "slice").Ok { + t.Error("slice should be true") + } +} + +func TestMin(t *testing.T) { + valid := Validation{} + + if valid.Min(-1, 0, "min0").Ok { + t.Error("-1 is less than the minimum value of 0 should be false") + } + if !valid.Min(1, 0, "min0").Ok { + t.Error("1 is greater or equal than the minimum value of 0 should be true") + } +} + +func TestMax(t *testing.T) { + valid := Validation{} + + if valid.Max(1, 0, "max0").Ok { + t.Error("1 is greater than the minimum value of 0 should be false") + } + if !valid.Max(-1, 0, "max0").Ok { + t.Error("-1 is less or equal than the maximum value of 0 should be true") + } +} + +func TestRange(t *testing.T) { + valid := Validation{} + + if valid.Range(-1, 0, 1, "range0_1").Ok { + t.Error("-1 is between 0 and 1 should be false") + } + if !valid.Range(1, 0, 1, "range0_1").Ok { + t.Error("1 is between 0 and 1 should be true") + } +} + +func TestMinSize(t *testing.T) { + valid := Validation{} + + if valid.MinSize("", 1, "minSize1").Ok { + t.Error("the length of \"\" is less than the minimum value of 1 should be false") + } + if !valid.MinSize("ok", 1, "minSize1").Ok { + t.Error("the length of \"ok\" is greater or equal than the minimum value of 1 should be true") + } + if valid.MinSize([]string{}, 1, "minSize1").Ok { + t.Error("the length of empty slice is less than the minimum value of 1 should be false") + } + if !valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok { + t.Error("the length of [\"ok\"] is greater or equal than the minimum value of 1 should be true") + } +} + +func TestMaxSize(t *testing.T) { + valid := Validation{} + + if valid.MaxSize("ok", 1, "maxSize1").Ok { + t.Error("the length of \"ok\" is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize("", 1, "maxSize1").Ok { + t.Error("the length of \"\" is less or equal than the maximum value of 1 should be true") + } + if valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok { + t.Error("the length of [\"ok\", false] is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize([]string{}, 1, "maxSize1").Ok { + t.Error("the length of empty slice is less or equal than the maximum value of 1 should be true") + } +} + +func TestLength(t *testing.T) { + valid := Validation{} + + if valid.Length("", 1, "length1").Ok { + t.Error("the length of \"\" must equal 1 should be false") + } + if !valid.Length("1", 1, "length1").Ok { + t.Error("the length of \"1\" must equal 1 should be true") + } + if valid.Length([]string{}, 1, "length1").Ok { + t.Error("the length of empty slice must equal 1 should be false") + } + if !valid.Length([]interface{}{"ok"}, 1, "length1").Ok { + t.Error("the length of [\"ok\"] must equal 1 should be true") + } +} + +func TestAlpha(t *testing.T) { + valid := Validation{} + + if valid.Alpha("a,1-@ $", "alpha").Ok { + t.Error("\"a,1-@ $\" are valid alpha characters should be false") + } + if !valid.Alpha("abCD", "alpha").Ok { + t.Error("\"abCD\" are valid alpha characters should be true") + } +} + +func TestNumeric(t *testing.T) { + valid := Validation{} + + if valid.Numeric("a,1-@ $", "numeric").Ok { + t.Error("\"a,1-@ $\" are valid numeric characters should be false") + } + if !valid.Numeric("1234", "numeric").Ok { + t.Error("\"1234\" are valid numeric characters should be true") + } +} + +func TestAlphaNumeric(t *testing.T) { + valid := Validation{} + + if valid.AlphaNumeric("a,1-@ $", "alphaNumeric").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric characters should be false") + } + if !valid.AlphaNumeric("1234aB", "alphaNumeric").Ok { + t.Error("\"1234aB\" are valid alpha or numeric characters should be true") + } +} + +func TestMatch(t *testing.T) { + valid := Validation{} + + if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") + } + if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") + } +} + +func TestNoMatch(t *testing.T) { + valid := Validation{} + + if valid.NoMatch("123@gmail", regexp.MustCompile(`[^\w\d]`), "nomatch").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false") + } + if !valid.NoMatch("123gmail", regexp.MustCompile(`[^\w\d]`), "match").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true") + } +} + +func TestAlphaDash(t *testing.T) { + valid := Validation{} + + if valid.AlphaDash("a,1-@ $", "alphaDash").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric or dash(-_) characters should be false") + } + if !valid.AlphaDash("1234aB-_", "alphaDash").Ok { + t.Error("\"1234aB\" are valid alpha or numeric or dash(-_) characters should be true") + } +} + +func TestEmail(t *testing.T) { + valid := Validation{} + + if valid.Email("not@a email", "email").Ok { + t.Error("\"not@a email\" is a valid email address should be false") + } + if !valid.Email("suchuangji@gmail.com", "email").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid email address should be true") + } + if valid.Email("@suchuangji@gmail.com", "email").Ok { + t.Error("\"@suchuangji@gmail.com\" is a valid email address should be false") + } + if valid.Email("suchuangji@gmail.com ok", "email").Ok { + t.Error("\"suchuangji@gmail.com ok\" is a valid email address should be false") + } +} + +func TestIP(t *testing.T) { + valid := Validation{} + + if valid.IP("11.255.255.256", "IP").Ok { + t.Error("\"11.255.255.256\" is a valid ip address should be false") + } + if !valid.IP("01.11.11.11", "IP").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid ip address should be true") + } +} + +func TestBase64(t *testing.T) { + valid := Validation{} + + if valid.Base64("suchuangji@gmail.com", "base64").Ok { + t.Error("\"suchuangji@gmail.com\" are a valid base64 characters should be false") + } + if !valid.Base64("c3VjaHVhbmdqaUBnbWFpbC5jb20=", "base64").Ok { + t.Error("\"c3VjaHVhbmdqaUBnbWFpbC5jb20=\" are a valid base64 characters should be true") + } +} + +func TestMobile(t *testing.T) { + valid := Validation{} + + validMobiles := []string{ + "19800008888", + "18800008888", + "18000008888", + "8618300008888", + "+8614700008888", + "17300008888", + "+8617100008888", + "8617500008888", + "8617400008888", + "16200008888", + "16500008888", + "16600008888", + "16700008888", + "13300008888", + "14900008888", + "15300008888", + "17300008888", + "17700008888", + "18000008888", + "18900008888", + "19100008888", + "19900008888", + "19300008888", + "13000008888", + "13100008888", + "13200008888", + "14500008888", + "15500008888", + "15600008888", + "16600008888", + "17100008888", + "17500008888", + "17600008888", + "18500008888", + "18600008888", + "13400008888", + "13500008888", + "13600008888", + "13700008888", + "13800008888", + "13900008888", + "14700008888", + "15000008888", + "15100008888", + "15200008888", + "15800008888", + "15900008888", + "17200008888", + "17800008888", + "18200008888", + "18300008888", + "18400008888", + "18700008888", + "18800008888", + "19800008888", + } + + for _, m := range validMobiles { + if !valid.Mobile(m, "mobile").Ok { + t.Error(m + " is a valid mobile phone number should be true") + } + } +} + +func TestTel(t *testing.T) { + valid := Validation{} + + if valid.Tel("222-00008888", "telephone").Ok { + t.Error("\"222-00008888\" is a valid telephone number should be false") + } + if !valid.Tel("022-70008888", "telephone").Ok { + t.Error("\"022-70008888\" is a valid telephone number should be true") + } + if !valid.Tel("02270008888", "telephone").Ok { + t.Error("\"02270008888\" is a valid telephone number should be true") + } + if !valid.Tel("70008888", "telephone").Ok { + t.Error("\"70008888\" is a valid telephone number should be true") + } +} + +func TestPhone(t *testing.T) { + valid := Validation{} + + if valid.Phone("222-00008888", "phone").Ok { + t.Error("\"222-00008888\" is a valid phone number should be false") + } + if !valid.Mobile("+8614700008888", "phone").Ok { + t.Error("\"+8614700008888\" is a valid phone number should be true") + } + if !valid.Tel("02270008888", "phone").Ok { + t.Error("\"02270008888\" is a valid phone number should be true") + } +} + +func TestZipCode(t *testing.T) { + valid := Validation{} + + if valid.ZipCode("", "zipcode").Ok { + t.Error("\"00008888\" is a valid zipcode should be false") + } + if !valid.ZipCode("536000", "zipcode").Ok { + t.Error("\"536000\" is a valid zipcode should be true") + } +} + +func TestValid(t *testing.T) { + type user struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + valid := Validation{} + + u := user{Name: "test@/test/;com", Age: 40} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Error("validation should be passed") + } + + uptr := &user{Name: "test", Age: 40} + valid.Clear() + b, err = valid.Valid(uptr) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Name.Match" { + t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key) + } + + u = user{Name: "test@/test/;com", Age: 180} + valid.Clear() + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Age.Range." { + t.Errorf("Message key should be `Age.Range` but got %s", valid.Errors[0].Key) + } +} + +func TestRecursiveValid(t *testing.T) { + type User struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + + type AnonymouseUser struct { + ID2 int + Name2 string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age2 int `valid:"Required;Range(1, 140)"` + } + + type Account struct { + Password string `valid:"Required"` + U User + AnonymouseUser + } + valid := Validation{} + + u := Account{Password: "abc123_", U: User{}} + b, err := valid.RecursiveValid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } +} + +func TestSkipValid(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + + IP string `valid:"IP"` + ReqIP string `valid:"Required;IP"` + + Mobile string `valid:"Mobile"` + ReqMobile string `valid:"Required;Mobile"` + + Tel string `valid:"Tel"` + ReqTel string `valid:"Required;Tel"` + + Phone string `valid:"Phone"` + ReqPhone string `valid:"Required;Phone"` + + ZipCode string `valid:"ZipCode"` + ReqZipCode string `valid:"Required;ZipCode"` + } + + u := User{ + ReqEmail: "a@a.com", + ReqIP: "127.0.0.1", + ReqMobile: "18888888888", + ReqTel: "02088888888", + ReqPhone: "02088888888", + ReqZipCode: "510000", + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } +} + +func TestPointer(t *testing.T) { + type User struct { + ID int + + Email *string `valid:"Email"` + ReqEmail *string `valid:"Required;Email"` + } + + u := User{ + ReqEmail: nil, + Email: nil, + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + validEmail := "a@a.com" + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + invalidEmail := "a@a" + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } +} + +func TestCanSkipAlso(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + MatchRange int `valid:"Range(10, 20)"` + } + + u := User{ + ReqEmail: "a@a.com", + Email: "", + MatchRange: 0, + } + + valid := Validation{RequiredFirst: true} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + valid.CanSkipAlso("Range") + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + +} diff --git a/pkg/adapter/validation/validators.go b/pkg/adapter/validation/validators.go new file mode 100644 index 00000000..1a063749 --- /dev/null +++ b/pkg/adapter/validation/validators.go @@ -0,0 +1,512 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "sync" + + "github.com/astaxie/beego/pkg/infrastructure/validation" +) + +// CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty +var CanSkipFuncs = validation.CanSkipFuncs + +// MessageTmpls store commond validate template +var MessageTmpls = map[string]string{ + "Required": "Can not be empty", + "Min": "Minimum is %d", + "Max": "Maximum is %d", + "Range": "Range is %d to %d", + "MinSize": "Minimum size is %d", + "MaxSize": "Maximum size is %d", + "Length": "Required length is %d", + "Alpha": "Must be valid alpha characters", + "Numeric": "Must be valid numeric characters", + "AlphaNumeric": "Must be valid alpha or numeric characters", + "Match": "Must match %s", + "NoMatch": "Must not match %s", + "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", + "Email": "Must be a valid email address", + "IP": "Must be a valid ip address", + "Base64": "Must be valid base64 characters", + "Mobile": "Must be valid mobile number", + "Tel": "Must be valid telephone number", + "Phone": "Must be valid telephone or mobile phone number", + "ZipCode": "Must be valid zipcode", +} + +var once sync.Once + +// SetDefaultMessage set default messages +// if not set, the default messages are +// "Required": "Can not be empty", +// "Min": "Minimum is %d", +// "Max": "Maximum is %d", +// "Range": "Range is %d to %d", +// "MinSize": "Minimum size is %d", +// "MaxSize": "Maximum size is %d", +// "Length": "Required length is %d", +// "Alpha": "Must be valid alpha characters", +// "Numeric": "Must be valid numeric characters", +// "AlphaNumeric": "Must be valid alpha or numeric characters", +// "Match": "Must match %s", +// "NoMatch": "Must not match %s", +// "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", +// "Email": "Must be a valid email address", +// "IP": "Must be a valid ip address", +// "Base64": "Must be valid base64 characters", +// "Mobile": "Must be valid mobile number", +// "Tel": "Must be valid telephone number", +// "Phone": "Must be valid telephone or mobile phone number", +// "ZipCode": "Must be valid zipcode", +func SetDefaultMessage(msg map[string]string) { + validation.SetDefaultMessage(msg) +} + +// Validator interface +type Validator interface { + IsSatisfied(interface{}) bool + DefaultMessage() string + GetKey() string + GetLimitValue() interface{} +} + +// Required struct +type Required validation.Required + +// IsSatisfied judge whether obj has value +func (r Required) IsSatisfied(obj interface{}) bool { + return validation.Required(r).IsSatisfied(obj) +} + +// DefaultMessage return the default error message +func (r Required) DefaultMessage() string { + return validation.Required(r).DefaultMessage() +} + +// GetKey return the r.Key +func (r Required) GetKey() string { + return validation.Required(r).GetKey() +} + +// GetLimitValue return nil now +func (r Required) GetLimitValue() interface{} { + return validation.Required(r).GetLimitValue() +} + +// Min check struct +type Min validation.Min + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Min) IsSatisfied(obj interface{}) bool { + return validation.Min(m).IsSatisfied(obj) +} + +// DefaultMessage return the default min error message +func (m Min) DefaultMessage() string { + return validation.Min(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Min) GetKey() string { + return validation.Min(m).GetKey() +} + +// GetLimitValue return the limit value, Min +func (m Min) GetLimitValue() interface{} { + return validation.Min(m).GetLimitValue() +} + +// Max validate struct +type Max validation.Max + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Max) IsSatisfied(obj interface{}) bool { + return validation.Max(m).IsSatisfied(obj) +} + +// DefaultMessage return the default max error message +func (m Max) DefaultMessage() string { + return validation.Max(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Max) GetKey() string { + return validation.Max(m).GetKey() +} + +// GetLimitValue return the limit value, Max +func (m Max) GetLimitValue() interface{} { + return validation.Max(m).GetLimitValue() +} + +// Range Requires an integer to be within Min, Max inclusive. +type Range validation.Range + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (r Range) IsSatisfied(obj interface{}) bool { + return validation.Range(r).IsSatisfied(obj) +} + +// DefaultMessage return the default Range error message +func (r Range) DefaultMessage() string { + return validation.Range(r).DefaultMessage() +} + +// GetKey return the m.Key +func (r Range) GetKey() string { + return validation.Range(r).GetKey() +} + +// GetLimitValue return the limit value, Max +func (r Range) GetLimitValue() interface{} { + return validation.Range(r).GetLimitValue() +} + +// MinSize Requires an array or string to be at least a given length. +type MinSize validation.MinSize + +// IsSatisfied judge whether obj is valid +func (m MinSize) IsSatisfied(obj interface{}) bool { + return validation.MinSize(m).IsSatisfied(obj) +} + +// DefaultMessage return the default MinSize error message +func (m MinSize) DefaultMessage() string { + return validation.MinSize(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m MinSize) GetKey() string { + return validation.MinSize(m).GetKey() +} + +// GetLimitValue return the limit value +func (m MinSize) GetLimitValue() interface{} { + return validation.MinSize(m).GetLimitValue() +} + +// MaxSize Requires an array or string to be at most a given length. +type MaxSize validation.MaxSize + +// IsSatisfied judge whether obj is valid +func (m MaxSize) IsSatisfied(obj interface{}) bool { + return validation.MaxSize(m).IsSatisfied(obj) +} + +// DefaultMessage return the default MaxSize error message +func (m MaxSize) DefaultMessage() string { + return validation.MaxSize(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m MaxSize) GetKey() string { + return validation.MaxSize(m).GetKey() +} + +// GetLimitValue return the limit value +func (m MaxSize) GetLimitValue() interface{} { + return validation.MaxSize(m).GetLimitValue() +} + +// Length Requires an array or string to be exactly a given length. +type Length validation.Length + +// IsSatisfied judge whether obj is valid +func (l Length) IsSatisfied(obj interface{}) bool { + return validation.Length(l).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (l Length) DefaultMessage() string { + return validation.Length(l).DefaultMessage() +} + +// GetKey return the m.Key +func (l Length) GetKey() string { + return validation.Length(l).GetKey() +} + +// GetLimitValue return the limit value +func (l Length) GetLimitValue() interface{} { + return validation.Length(l).GetLimitValue() +} + +// Alpha check the alpha +type Alpha validation.Alpha + +// IsSatisfied judge whether obj is valid +func (a Alpha) IsSatisfied(obj interface{}) bool { + return validation.Alpha(a).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (a Alpha) DefaultMessage() string { + return validation.Alpha(a).DefaultMessage() +} + +// GetKey return the m.Key +func (a Alpha) GetKey() string { + return validation.Alpha(a).GetKey() +} + +// GetLimitValue return the limit value +func (a Alpha) GetLimitValue() interface{} { + return validation.Alpha(a).GetLimitValue() +} + +// Numeric check number +type Numeric validation.Numeric + +// IsSatisfied judge whether obj is valid +func (n Numeric) IsSatisfied(obj interface{}) bool { + return validation.Numeric(n).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (n Numeric) DefaultMessage() string { + return validation.Numeric(n).DefaultMessage() +} + +// GetKey return the n.Key +func (n Numeric) GetKey() string { + return validation.Numeric(n).GetKey() +} + +// GetLimitValue return the limit value +func (n Numeric) GetLimitValue() interface{} { + return validation.Numeric(n).GetLimitValue() +} + +// AlphaNumeric check alpha and number +type AlphaNumeric validation.AlphaNumeric + +// IsSatisfied judge whether obj is valid +func (a AlphaNumeric) IsSatisfied(obj interface{}) bool { + return validation.AlphaNumeric(a).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (a AlphaNumeric) DefaultMessage() string { + return validation.AlphaNumeric(a).DefaultMessage() +} + +// GetKey return the a.Key +func (a AlphaNumeric) GetKey() string { + return validation.AlphaNumeric(a).GetKey() +} + +// GetLimitValue return the limit value +func (a AlphaNumeric) GetLimitValue() interface{} { + return validation.AlphaNumeric(a).GetLimitValue() +} + +// Match Requires a string to match a given regex. +type Match validation.Match + +// IsSatisfied judge whether obj is valid +func (m Match) IsSatisfied(obj interface{}) bool { + return validation.Match(m).IsSatisfied(obj) +} + +// DefaultMessage return the default Match error message +func (m Match) DefaultMessage() string { + return validation.Match(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Match) GetKey() string { + return validation.Match(m).GetKey() +} + +// GetLimitValue return the limit value +func (m Match) GetLimitValue() interface{} { + return validation.Match(m).GetLimitValue() +} + +// NoMatch Requires a string to not match a given regex. +type NoMatch validation.NoMatch + +// IsSatisfied judge whether obj is valid +func (n NoMatch) IsSatisfied(obj interface{}) bool { + return validation.NoMatch(n).IsSatisfied(obj) +} + +// DefaultMessage return the default NoMatch error message +func (n NoMatch) DefaultMessage() string { + return validation.NoMatch(n).DefaultMessage() +} + +// GetKey return the n.Key +func (n NoMatch) GetKey() string { + return validation.NoMatch(n).GetKey() +} + +// GetLimitValue return the limit value +func (n NoMatch) GetLimitValue() interface{} { + return validation.NoMatch(n).GetLimitValue() +} + +// AlphaDash check not Alpha +type AlphaDash validation.AlphaDash + +// DefaultMessage return the default AlphaDash error message +func (a AlphaDash) DefaultMessage() string { + return validation.AlphaDash(a).DefaultMessage() +} + +// GetKey return the n.Key +func (a AlphaDash) GetKey() string { + return validation.AlphaDash(a).GetKey() +} + +// GetLimitValue return the limit value +func (a AlphaDash) GetLimitValue() interface{} { + return validation.AlphaDash(a).GetLimitValue() +} + +// Email check struct +type Email validation.Email + +// DefaultMessage return the default Email error message +func (e Email) DefaultMessage() string { + return validation.Email(e).DefaultMessage() +} + +// GetKey return the n.Key +func (e Email) GetKey() string { + return validation.Email(e).GetKey() +} + +// GetLimitValue return the limit value +func (e Email) GetLimitValue() interface{} { + return validation.Email(e).GetLimitValue() +} + +// IP check struct +type IP validation.IP + +// DefaultMessage return the default IP error message +func (i IP) DefaultMessage() string { + return validation.IP(i).DefaultMessage() +} + +// GetKey return the i.Key +func (i IP) GetKey() string { + return validation.IP(i).GetKey() +} + +// GetLimitValue return the limit value +func (i IP) GetLimitValue() interface{} { + return validation.IP(i).GetLimitValue() +} + +// Base64 check struct +type Base64 validation.Base64 + +// DefaultMessage return the default Base64 error message +func (b Base64) DefaultMessage() string { + return validation.Base64(b).DefaultMessage() +} + +// GetKey return the b.Key +func (b Base64) GetKey() string { + return validation.Base64(b).GetKey() +} + +// GetLimitValue return the limit value +func (b Base64) GetLimitValue() interface{} { + return validation.Base64(b).GetLimitValue() +} + +// Mobile check struct +type Mobile validation.Mobile + +// DefaultMessage return the default Mobile error message +func (m Mobile) DefaultMessage() string { + return validation.Mobile(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Mobile) GetKey() string { + return validation.Mobile(m).GetKey() +} + +// GetLimitValue return the limit value +func (m Mobile) GetLimitValue() interface{} { + return validation.Mobile(m).GetLimitValue() +} + +// Tel check telephone struct +type Tel validation.Tel + +// DefaultMessage return the default Tel error message +func (t Tel) DefaultMessage() string { + return validation.Tel(t).DefaultMessage() +} + +// GetKey return the t.Key +func (t Tel) GetKey() string { + return validation.Tel(t).GetKey() +} + +// GetLimitValue return the limit value +func (t Tel) GetLimitValue() interface{} { + return validation.Tel(t).GetLimitValue() +} + +// Phone just for chinese telephone or mobile phone number +type Phone validation.Phone + +// IsSatisfied judge whether obj is valid +func (p Phone) IsSatisfied(obj interface{}) bool { + return validation.Phone(p).IsSatisfied(obj) +} + +// DefaultMessage return the default Phone error message +func (p Phone) DefaultMessage() string { + return validation.Phone(p).DefaultMessage() +} + +// GetKey return the p.Key +func (p Phone) GetKey() string { + return validation.Phone(p).GetKey() +} + +// GetLimitValue return the limit value +func (p Phone) GetLimitValue() interface{} { + return validation.Phone(p).GetLimitValue() +} + +// ZipCode check the zip struct +type ZipCode validation.ZipCode + +// DefaultMessage return the default Zip error message +func (z ZipCode) DefaultMessage() string { + return validation.ZipCode(z).DefaultMessage() +} + +// GetKey return the z.Key +func (z ZipCode) GetKey() string { + return validation.ZipCode(z).GetKey() +} + +// GetLimitValue return the limit value +func (z ZipCode) GetLimitValue() interface{} { + return validation.ZipCode(z).GetLimitValue() +} From 3530457ff9a51e721be139bec94de2299a027197 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Thu, 3 Sep 2020 21:34:46 +0800 Subject: [PATCH 21/35] Adapter: toolbox module --- pkg/adapter/toolbox/healthcheck.go | 52 +++++ pkg/adapter/toolbox/profile.go | 50 +++++ pkg/adapter/toolbox/profile_test.go | 28 +++ pkg/adapter/toolbox/statistics.go | 50 +++++ pkg/adapter/toolbox/statistics_test.go | 40 ++++ pkg/adapter/toolbox/task.go | 286 +++++++++++++++++++++++++ pkg/adapter/toolbox/task_test.go | 63 ++++++ pkg/task/task.go | 6 +- 8 files changed, 572 insertions(+), 3 deletions(-) create mode 100644 pkg/adapter/toolbox/healthcheck.go create mode 100644 pkg/adapter/toolbox/profile.go create mode 100644 pkg/adapter/toolbox/profile_test.go create mode 100644 pkg/adapter/toolbox/statistics.go create mode 100644 pkg/adapter/toolbox/statistics_test.go create mode 100644 pkg/adapter/toolbox/task.go create mode 100644 pkg/adapter/toolbox/task_test.go diff --git a/pkg/adapter/toolbox/healthcheck.go b/pkg/adapter/toolbox/healthcheck.go new file mode 100644 index 00000000..56be8089 --- /dev/null +++ b/pkg/adapter/toolbox/healthcheck.go @@ -0,0 +1,52 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package toolbox healthcheck +// +// type DatabaseCheck struct { +// } +// +// func (dc *DatabaseCheck) Check() error { +// if dc.isConnected() { +// return nil +// } else { +// return errors.New("can't connect database") +// } +// } +// +// AddHealthCheck("database",&DatabaseCheck{}) +// +// more docs: http://beego.me/docs/module/toolbox.md +package toolbox + +import ( + "github.com/astaxie/beego/pkg/infrastructure/governor" +) + +// AdminCheckList holds health checker map +// Deprecated using governor.AdminCheckList +var AdminCheckList map[string]HealthChecker + +// HealthChecker health checker interface +type HealthChecker governor.HealthChecker + +// AddHealthCheck add health checker with name string +func AddHealthCheck(name string, hc HealthChecker) { + governor.AddHealthCheck(name, hc) + AdminCheckList[name] = hc +} + +func init() { + AdminCheckList = make(map[string]HealthChecker) +} diff --git a/pkg/adapter/toolbox/profile.go b/pkg/adapter/toolbox/profile.go new file mode 100644 index 00000000..16cf80b1 --- /dev/null +++ b/pkg/adapter/toolbox/profile.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "io" + "os" + "time" + + "github.com/astaxie/beego/pkg/infrastructure/governor" +) + +var startTime = time.Now() +var pid int + +func init() { + pid = os.Getpid() +} + +// ProcessInput parse input command string +func ProcessInput(input string, w io.Writer) { + governor.ProcessInput(input, w) +} + +// MemProf record memory profile in pprof +func MemProf(w io.Writer) { + governor.MemProf(w) +} + +// GetCPUProfile start cpu profile monitor +func GetCPUProfile(w io.Writer) { + governor.GetCPUProfile(w) +} + +// PrintGCSummary print gc information to io.Writer +func PrintGCSummary(w io.Writer) { + governor.PrintGCSummary(w) +} diff --git a/pkg/adapter/toolbox/profile_test.go b/pkg/adapter/toolbox/profile_test.go new file mode 100644 index 00000000..07a20c4e --- /dev/null +++ b/pkg/adapter/toolbox/profile_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "os" + "testing" +) + +func TestProcessInput(t *testing.T) { + ProcessInput("lookup goroutine", os.Stdout) + ProcessInput("lookup heap", os.Stdout) + ProcessInput("lookup threadcreate", os.Stdout) + ProcessInput("lookup block", os.Stdout) + ProcessInput("gc summary", os.Stdout) +} diff --git a/pkg/adapter/toolbox/statistics.go b/pkg/adapter/toolbox/statistics.go new file mode 100644 index 00000000..b7d3bda9 --- /dev/null +++ b/pkg/adapter/toolbox/statistics.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "time" + + "github.com/astaxie/beego/pkg/server/web" +) + +// Statistics struct +type Statistics web.Statistics + +// URLMap contains several statistics struct to log different data +type URLMap web.URLMap + +// AddStatistics add statistics task. +// it needs request method, request url, request controller and statistics time duration +func (m *URLMap) AddStatistics(requestMethod, requestURL, requestController string, requesttime time.Duration) { + (*web.URLMap)(m).AddStatistics(requestMethod, requestURL, requestController, requesttime) +} + +// GetMap put url statistics result in io.Writer +func (m *URLMap) GetMap() map[string]interface{} { + return (*web.URLMap)(m).GetMap() +} + +// GetMapData return all mapdata +func (m *URLMap) GetMapData() []map[string]interface{} { + return (*web.URLMap)(m).GetMapData() +} + +// StatisticsMap hosld global statistics data map +var StatisticsMap *URLMap + +func init() { + StatisticsMap = (*URLMap)(web.StatisticsMap) +} diff --git a/pkg/adapter/toolbox/statistics_test.go b/pkg/adapter/toolbox/statistics_test.go new file mode 100644 index 00000000..ac29476c --- /dev/null +++ b/pkg/adapter/toolbox/statistics_test.go @@ -0,0 +1,40 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "encoding/json" + "testing" + "time" +) + +func TestStatics(t *testing.T) { + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(2000)) + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(120000)) + StatisticsMap.AddStatistics("GET", "/api/user", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("POST", "/api/admin", "&admin.user", time.Duration(14000)) + StatisticsMap.AddStatistics("POST", "/api/user/astaxie", "&admin.user", time.Duration(12000)) + StatisticsMap.AddStatistics("POST", "/api/user/xiemengjun", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("DELETE", "/api/user", "&admin.user", time.Duration(1400)) + t.Log(StatisticsMap.GetMap()) + + data := StatisticsMap.GetMapData() + b, err := json.Marshal(data) + if err != nil { + t.Errorf(err.Error()) + } + + t.Log(string(b)) +} diff --git a/pkg/adapter/toolbox/task.go b/pkg/adapter/toolbox/task.go new file mode 100644 index 00000000..2a6d9aa6 --- /dev/null +++ b/pkg/adapter/toolbox/task.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "context" + "sort" + "time" + + "github.com/astaxie/beego/pkg/task" +) + +// The bounds for each field. +var ( + AdminTaskList map[string]Tasker +) + +const ( + // Set the top bit if a star was included in the expression. + starBit = 1 << 63 +) + +// Schedule time taks schedule +type Schedule task.Schedule + +// TaskFunc task func type +type TaskFunc func() error + +// Tasker task interface +type Tasker interface { + GetSpec() string + GetStatus() string + Run() error + SetNext(time.Time) + GetNext() time.Time + SetPrev(time.Time) + GetPrev() time.Time +} + +// task error +type taskerr struct { + t time.Time + errinfo string +} + +// Task task struct +// Deprecated +type Task struct { + // Deprecated + Taskname string + // Deprecated + Spec *Schedule + // Deprecated + SpecStr string + // Deprecated + DoFunc TaskFunc + // Deprecated + Prev time.Time + // Deprecated + Next time.Time + // Deprecated + Errlist []*taskerr // like errtime:errinfo + // Deprecated + ErrLimit int // max length for the errlist, 0 stand for no limit + + delegate *task.Task +} + +// NewTask add new task with name, time and func +func NewTask(tname string, spec string, f TaskFunc) *Task { + + task := task.NewTask(tname, spec, func(ctx context.Context) error { + return f() + }) + return &Task{ + delegate: task, + } +} + +// GetSpec get spec string +func (t *Task) GetSpec() string { + t.initDelegate() + + return t.delegate.GetSpec(context.Background()) +} + +// GetStatus get current task status +func (t *Task) GetStatus() string { + + t.initDelegate() + + return t.delegate.GetStatus(context.Background()) +} + +// Run run all tasks +func (t *Task) Run() error { + t.initDelegate() + return t.delegate.Run(context.Background()) +} + +// SetNext set next time for this task +func (t *Task) SetNext(now time.Time) { + t.initDelegate() + t.delegate.SetNext(context.Background(), now) +} + +// GetNext get the next call time of this task +func (t *Task) GetNext() time.Time { + t.initDelegate() + return t.delegate.GetNext(context.Background()) +} + +// SetPrev set prev time of this task +func (t *Task) SetPrev(now time.Time) { + t.initDelegate() + t.delegate.SetPrev(context.Background(), now) +} + +// GetPrev get prev time of this task +func (t *Task) GetPrev() time.Time { + t.initDelegate() + return t.delegate.GetPrev(context.Background()) +} + +// six columns mean: +// second:0-59 +// minute:0-59 +// hour:1-23 +// day:1-31 +// month:1-12 +// week:0-6(0 means Sunday) + +// SetCron some signals: +// *: any time +// ,:  separate signal +//    -:duration +// /n : do as n times of time duration +// /////////////////////////////////////////////////////// +// 0/30 * * * * * every 30s +// 0 43 21 * * * 21:43 +// 0 15 05 * * *    05:15 +// 0 0 17 * * * 17:00 +// 0 0 17 * * 1 17:00 in every Monday +// 0 0,10 17 * * 0,2,3 17:00 and 17:10 in every Sunday, Tuesday and Wednesday +// 0 0-10 17 1 * * 17:00 to 17:10 in 1 min duration each time on the first day of month +// 0 0 0 1,15 * 1 0:00 on the 1st day and 15th day of month +// 0 42 4 1 * *     4:42 on the 1st day of month +// 0 0 21 * * 1-6   21:00 from Monday to Saturday +// 0 0,10,20,30,40,50 * * * *  every 10 min duration +// 0 */10 * * * *        every 10 min duration +// 0 * 1 * * *         1:00 to 1:59 in 1 min duration each time +// 0 0 1 * * *         1:00 +// 0 0 */1 * * *        0 min of hour in 1 hour duration +// 0 0 * * * *         0 min of hour in 1 hour duration +// 0 2 8-20/3 * * *       8:02, 11:02, 14:02, 17:02, 20:02 +// 0 30 5 1,15 * *       5:30 on the 1st day and 15th day of month +func (t *Task) SetCron(spec string) { + t.initDelegate() + t.delegate.SetCron(spec) +} + +func (t *Task) initDelegate() { + if t.delegate == nil { + t.delegate = &task.Task{ + Taskname: t.Taskname, + Spec: (*task.Schedule)(t.Spec), + SpecStr: t.SpecStr, + DoFunc: func(ctx context.Context) error { + return t.DoFunc() + }, + Prev: t.Prev, + Next: t.Next, + ErrLimit: t.ErrLimit, + } + } +} + +// Next set schedule to next time +func (s *Schedule) Next(t time.Time) time.Time { + return (*task.Schedule)(s).Next(t) +} + +// StartTask start all tasks +func StartTask() { + task.StartTask() +} + +// StopTask stop all tasks +func StopTask() { + task.StopTask() +} + +// AddTask add task with name +func AddTask(taskname string, t Tasker) { + task.AddTask(taskname, &oldToNewAdapter{delegate: t}) +} + +// DeleteTask delete task with name +func DeleteTask(taskname string) { + task.DeleteTask(taskname) +} + +// MapSorter sort map for tasker +type MapSorter task.MapSorter + +// NewMapSorter create new tasker map +func NewMapSorter(m map[string]Tasker) *MapSorter { + + newTaskerMap := make(map[string]task.Tasker, len(m)) + + for key, value := range m { + newTaskerMap[key] = &oldToNewAdapter{ + delegate: value, + } + } + + return (*MapSorter)(task.NewMapSorter(newTaskerMap)) +} + +// Sort sort tasker map +func (ms *MapSorter) Sort() { + sort.Sort(ms) +} + +func (ms *MapSorter) Len() int { return len(ms.Keys) } +func (ms *MapSorter) Less(i, j int) bool { + if ms.Vals[i].GetNext(context.Background()).IsZero() { + return false + } + if ms.Vals[j].GetNext(context.Background()).IsZero() { + return true + } + return ms.Vals[i].GetNext(context.Background()).Before(ms.Vals[j].GetNext(context.Background())) +} +func (ms *MapSorter) Swap(i, j int) { + ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] + ms.Keys[i], ms.Keys[j] = ms.Keys[j], ms.Keys[i] +} + +func init() { + AdminTaskList = make(map[string]Tasker) +} + +type oldToNewAdapter struct { + delegate Tasker +} + +func (o *oldToNewAdapter) GetSpec(ctx context.Context) string { + return o.delegate.GetSpec() +} + +func (o *oldToNewAdapter) GetStatus(ctx context.Context) string { + return o.delegate.GetStatus() +} + +func (o *oldToNewAdapter) Run(ctx context.Context) error { + return o.delegate.Run() +} + +func (o *oldToNewAdapter) SetNext(ctx context.Context, t time.Time) { + o.delegate.SetNext(t) +} + +func (o *oldToNewAdapter) GetNext(ctx context.Context) time.Time { + return o.delegate.GetNext() +} + +func (o *oldToNewAdapter) SetPrev(ctx context.Context, t time.Time) { + o.delegate.SetPrev(t) +} + +func (o *oldToNewAdapter) GetPrev(ctx context.Context) time.Time { + return o.delegate.GetPrev() +} diff --git a/pkg/adapter/toolbox/task_test.go b/pkg/adapter/toolbox/task_test.go new file mode 100644 index 00000000..596bc9c5 --- /dev/null +++ b/pkg/adapter/toolbox/task_test.go @@ -0,0 +1,63 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "fmt" + "sync" + "testing" + "time" +) + +func TestParse(t *testing.T) { + tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) + err := tk.Run() + if err != nil { + t.Fatal(err) + } + AddTask("taska", tk) + StartTask() + time.Sleep(6 * time.Second) + StopTask() +} + +func TestSpec(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) + tk2 := NewTask("tk2", "0,10,20 * * * * *", func() error { fmt.Println("tk2"); wg.Done(); return nil }) + tk3 := NewTask("tk3", "0 10 * * * *", func() error { fmt.Println("tk3"); wg.Done(); return nil }) + + AddTask("tk1", tk1) + AddTask("tk2", tk2) + AddTask("tk3", tk3) + StartTask() + defer StopTask() + + select { + case <-time.After(200 * time.Second): + t.FailNow() + case <-wait(wg): + } +} + +func wait(wg *sync.WaitGroup) chan bool { + ch := make(chan bool) + go func() { + wg.Wait() + ch <- true + }() + return ch +} diff --git a/pkg/task/task.go b/pkg/task/task.go index e2962000..bcadb956 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -83,7 +83,7 @@ type Schedule struct { } // TaskFunc task func type -type TaskFunc func() error +type TaskFunc func(ctx context.Context) error // Tasker task interface type Tasker interface { @@ -148,8 +148,8 @@ func (t *Task) GetStatus(context.Context) string { } // Run run all tasks -func (t *Task) Run(context.Context) error { - err := t.DoFunc() +func (t *Task) Run(ctx context.Context) error { + err := t.DoFunc(ctx) if err != nil { index := t.errCnt % t.ErrLimit t.Errlist[index] = &taskerr{t: t.Next, errinfo: err.Error()} From 8ef9965eef3250a5578739258c9f12315ead1771 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Thu, 3 Sep 2020 23:36:09 +0800 Subject: [PATCH 22/35] Adapter: session module --- .../session/couchbase/sess_couchbase.go | 118 ++++++ pkg/adapter/session/ledis/ledis_session.go | 86 +++++ pkg/adapter/session/memcache/sess_memcache.go | 118 ++++++ pkg/adapter/session/mysql/sess_mysql.go | 135 +++++++ .../session/postgres/sess_postgresql.go | 139 ++++++++ pkg/adapter/session/provider_adapter.go | 104 ++++++ pkg/adapter/session/redis/sess_redis.go | 121 +++++++ .../session/redis_cluster/redis_cluster.go | 120 +++++++ .../redis_sentinel/sess_redis_sentinel.go | 121 +++++++ .../sess_redis_sentinel_test.go | 90 +++++ pkg/adapter/session/sess_cookie.go | 114 ++++++ pkg/adapter/session/sess_cookie_test.go | 105 ++++++ pkg/adapter/session/sess_file.go | 106 ++++++ pkg/adapter/session/sess_file_test.go | 336 ++++++++++++++++++ pkg/adapter/session/sess_mem.go | 106 ++++++ pkg/adapter/session/sess_mem_test.go | 58 +++ pkg/adapter/session/sess_test.go | 51 +++ pkg/adapter/session/sess_utils.go | 29 ++ pkg/adapter/session/session.go | 166 +++++++++ pkg/adapter/session/ssdb/sess_ssdb.go | 84 +++++ pkg/adapter/session/store_adapter.go | 84 +++++ pkg/infrastructure/session/sess_cookie.go | 2 +- pkg/infrastructure/session/sess_mem.go | 6 +- 23 files changed, 2395 insertions(+), 4 deletions(-) create mode 100644 pkg/adapter/session/couchbase/sess_couchbase.go create mode 100644 pkg/adapter/session/ledis/ledis_session.go create mode 100644 pkg/adapter/session/memcache/sess_memcache.go create mode 100644 pkg/adapter/session/mysql/sess_mysql.go create mode 100644 pkg/adapter/session/postgres/sess_postgresql.go create mode 100644 pkg/adapter/session/provider_adapter.go create mode 100644 pkg/adapter/session/redis/sess_redis.go create mode 100644 pkg/adapter/session/redis_cluster/redis_cluster.go create mode 100644 pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go create mode 100644 pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go create mode 100644 pkg/adapter/session/sess_cookie.go create mode 100644 pkg/adapter/session/sess_cookie_test.go create mode 100644 pkg/adapter/session/sess_file.go create mode 100644 pkg/adapter/session/sess_file_test.go create mode 100644 pkg/adapter/session/sess_mem.go create mode 100644 pkg/adapter/session/sess_mem_test.go create mode 100644 pkg/adapter/session/sess_test.go create mode 100644 pkg/adapter/session/sess_utils.go create mode 100644 pkg/adapter/session/session.go create mode 100644 pkg/adapter/session/ssdb/sess_ssdb.go create mode 100644 pkg/adapter/session/store_adapter.go diff --git a/pkg/adapter/session/couchbase/sess_couchbase.go b/pkg/adapter/session/couchbase/sess_couchbase.go new file mode 100644 index 00000000..bce09641 --- /dev/null +++ b/pkg/adapter/session/couchbase/sess_couchbase.go @@ -0,0 +1,118 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package couchbase for session provider +// +// depend on github.com/couchbaselabs/go-couchbasee +// +// go install github.com/couchbaselabs/go-couchbase +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/couchbase" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("couchbase", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package couchbase + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + beecb "github.com/astaxie/beego/pkg/infrastructure/session/couchbase" +) + +// SessionStore store each session +type SessionStore beecb.SessionStore + +// Provider couchabse provided +type Provider beecb.Provider + +// Set value to couchabse session +func (cs *SessionStore) Set(key, value interface{}) error { + return (*beecb.SessionStore)(cs).Set(context.Background(), key, value) +} + +// Get value from couchabse session +func (cs *SessionStore) Get(key interface{}) interface{} { + return (*beecb.SessionStore)(cs).Get(context.Background(), key) +} + +// Delete value in couchbase session by given key +func (cs *SessionStore) Delete(key interface{}) error { + return (*beecb.SessionStore)(cs).Delete(context.Background(), key) +} + +// Flush Clean all values in couchbase session +func (cs *SessionStore) Flush() error { + return (*beecb.SessionStore)(cs).Flush(context.Background()) +} + +// SessionID Get couchbase session store id +func (cs *SessionStore) SessionID() string { + return (*beecb.SessionStore)(cs).SessionID(context.Background()) +} + +// SessionRelease Write couchbase session with Gob string +func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beecb.SessionStore)(cs).SessionRelease(context.Background(), w) +} + +// SessionInit init couchbase session +// savepath like couchbase server REST/JSON URL +// e.g. http://host:port/, Pool, Bucket +func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*beecb.Provider)(cp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read couchbase session by sid +func (cp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beecb.Provider)(cp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist Check couchbase session exist. +// it checkes sid exist or not. +func (cp *Provider) SessionExist(sid string) bool { + res, _ := (*beecb.Provider)(cp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate remove oldsid and use sid to generate new session +func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beecb.Provider)(cp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy Remove bucket in this couchbase +func (cp *Provider) SessionDestroy(sid string) error { + return (*beecb.Provider)(cp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Recycle +func (cp *Provider) SessionGC() { + (*beecb.Provider)(cp).SessionGC(context.Background()) +} + +// SessionAll return all active session +func (cp *Provider) SessionAll() int { + return (*beecb.Provider)(cp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/ledis/ledis_session.go b/pkg/adapter/session/ledis/ledis_session.go new file mode 100644 index 00000000..96198837 --- /dev/null +++ b/pkg/adapter/session/ledis/ledis_session.go @@ -0,0 +1,86 @@ +// Package ledis provide session Provider +package ledis + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + beeLedis "github.com/astaxie/beego/pkg/infrastructure/session/ledis" +) + +// SessionStore ledis session store +type SessionStore beeLedis.SessionStore + +// Set value in ledis session +func (ls *SessionStore) Set(key, value interface{}) error { + return (*beeLedis.SessionStore)(ls).Set(context.Background(), key, value) +} + +// Get value in ledis session +func (ls *SessionStore) Get(key interface{}) interface{} { + return (*beeLedis.SessionStore)(ls).Get(context.Background(), key) +} + +// Delete value in ledis session +func (ls *SessionStore) Delete(key interface{}) error { + return (*beeLedis.SessionStore)(ls).Delete(context.Background(), key) +} + +// Flush clear all values in ledis session +func (ls *SessionStore) Flush() error { + return (*beeLedis.SessionStore)(ls).Flush(context.Background()) +} + +// SessionID get ledis session id +func (ls *SessionStore) SessionID() string { + return (*beeLedis.SessionStore)(ls).SessionID(context.Background()) +} + +// SessionRelease save session values to ledis +func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beeLedis.SessionStore)(ls).SessionRelease(context.Background(), w) +} + +// Provider ledis session provider +type Provider beeLedis.Provider + +// SessionInit init ledis session +// savepath like ledis server saveDataPath,pool size +// e.g. 127.0.0.1:6379,100,astaxie +func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*beeLedis.Provider)(lp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read ledis session by sid +func (lp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beeLedis.Provider)(lp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check ledis session exist by sid +func (lp *Provider) SessionExist(sid string) bool { + res, _ := (*beeLedis.Provider)(lp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for ledis session +func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beeLedis.Provider)(lp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete ledis session by id +func (lp *Provider) SessionDestroy(sid string) error { + return (*beeLedis.Provider)(lp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (lp *Provider) SessionGC() { + (*beeLedis.Provider)(lp).SessionGC(context.Background()) +} + +// SessionAll return all active session +func (lp *Provider) SessionAll() int { + return (*beeLedis.Provider)(lp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/memcache/sess_memcache.go b/pkg/adapter/session/memcache/sess_memcache.go new file mode 100644 index 00000000..8afa79aa --- /dev/null +++ b/pkg/adapter/session/memcache/sess_memcache.go @@ -0,0 +1,118 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memcache for session provider +// +// depend on github.com/bradfitz/gomemcache/memcache +// +// go install github.com/bradfitz/gomemcache/memcache +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/memcache" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package memcache + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + beemem "github.com/astaxie/beego/pkg/infrastructure/session/memcache" +) + +// SessionStore memcache session store +type SessionStore beemem.SessionStore + +// Set value in memcache session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*beemem.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in memcache session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*beemem.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in memcache session +func (rs *SessionStore) Delete(key interface{}) error { + return (*beemem.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in memcache session +func (rs *SessionStore) Flush() error { + return (*beemem.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get memcache session id +func (rs *SessionStore) SessionID() string { + return (*beemem.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to memcache +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beemem.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// MemProvider memcache session provider +type MemProvider beemem.MemProvider + +// SessionInit init memcache session +// savepath like +// e.g. 127.0.0.1:9090 +func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + return (*beemem.MemProvider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read memcache session by sid +func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { + s, err := (*beemem.MemProvider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check memcache session exist by sid +func (rp *MemProvider) SessionExist(sid string) bool { + res, _ := (*beemem.MemProvider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for memcache session +func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beemem.MemProvider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete memcache session by id +func (rp *MemProvider) SessionDestroy(sid string) error { + return (*beemem.MemProvider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *MemProvider) SessionGC() { + (*beemem.MemProvider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *MemProvider) SessionAll() int { + return (*beemem.MemProvider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/mysql/sess_mysql.go b/pkg/adapter/session/mysql/sess_mysql.go new file mode 100644 index 00000000..1850a380 --- /dev/null +++ b/pkg/adapter/session/mysql/sess_mysql.go @@ -0,0 +1,135 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mysql for session provider +// +// depends on github.com/go-sql-driver/mysql: +// +// go install github.com/go-sql-driver/mysql +// +// mysql session support need create table as sql: +// CREATE TABLE `session` ( +// `session_key` char(64) NOT NULL, +// `session_data` blob, +// `session_expiry` int(11) unsigned NOT NULL, +// PRIMARY KEY (`session_key`) +// ) ENGINE=MyISAM DEFAULT CHARSET=utf8; +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/mysql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package mysql + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + "github.com/astaxie/beego/pkg/infrastructure/session/mysql" + + // import mysql driver + _ "github.com/go-sql-driver/mysql" +) + +var ( + // TableName store the session in MySQL + TableName = mysql.TableName + mysqlpder = &Provider{} +) + +// SessionStore mysql session store +type SessionStore mysql.SessionStore + +// Set value in mysql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + return (*mysql.SessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from mysql session +func (st *SessionStore) Get(key interface{}) interface{} { + return (*mysql.SessionStore)(st).Get(context.Background(), key) +} + +// Delete value in mysql session +func (st *SessionStore) Delete(key interface{}) error { + return (*mysql.SessionStore)(st).Delete(context.Background(), key) +} + +// Flush clear all values in mysql session +func (st *SessionStore) Flush() error { + return (*mysql.SessionStore)(st).Flush(context.Background()) +} + +// SessionID get session id of this mysql session store +func (st *SessionStore) SessionID() string { + return (*mysql.SessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease save mysql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + (*mysql.SessionStore)(st).SessionRelease(context.Background(), w) +} + +// Provider mysql session provider +type Provider mysql.Provider + +// SessionInit init mysql session. +// savepath is the connection string of mysql. +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*mysql.Provider)(mp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead get mysql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*mysql.Provider)(mp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check mysql session exist +func (mp *Provider) SessionExist(sid string) bool { + res, _ := (*mysql.Provider)(mp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for mysql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*mysql.Provider)(mp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete mysql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + return (*mysql.Provider)(mp).SessionDestroy(context.Background(), sid) +} + +// SessionGC delete expired values in mysql session +func (mp *Provider) SessionGC() { + (*mysql.Provider)(mp).SessionGC(context.Background()) +} + +// SessionAll count values in mysql session +func (mp *Provider) SessionAll() int { + return (*mysql.Provider)(mp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/postgres/sess_postgresql.go b/pkg/adapter/session/postgres/sess_postgresql.go new file mode 100644 index 00000000..de1adbc4 --- /dev/null +++ b/pkg/adapter/session/postgres/sess_postgresql.go @@ -0,0 +1,139 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres for session provider +// +// depends on github.com/lib/pq: +// +// go install github.com/lib/pq +// +// +// needs this table in your database: +// +// CREATE TABLE session ( +// session_key char(64) NOT NULL, +// session_data bytea, +// session_expiry timestamp NOT NULL, +// CONSTRAINT session_key PRIMARY KEY(session_key) +// ); +// +// will be activated with these settings in app.conf: +// +// SessionOn = true +// SessionProvider = postgresql +// SessionSavePath = "user=a password=b dbname=c sslmode=disable" +// SessionName = session +// +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/postgresql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package postgres + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + // import postgresql Driver + _ "github.com/lib/pq" + + "github.com/astaxie/beego/pkg/infrastructure/session/postgres" +) + +// SessionStore postgresql session store +type SessionStore postgres.SessionStore + +// Set value in postgresql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + return (*postgres.SessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from postgresql session +func (st *SessionStore) Get(key interface{}) interface{} { + return (*postgres.SessionStore)(st).Get(context.Background(), key) +} + +// Delete value in postgresql session +func (st *SessionStore) Delete(key interface{}) error { + return (*postgres.SessionStore)(st).Delete(context.Background(), key) +} + +// Flush clear all values in postgresql session +func (st *SessionStore) Flush() error { + return (*postgres.SessionStore)(st).Flush(context.Background()) +} + +// SessionID get session id of this postgresql session store +func (st *SessionStore) SessionID() string { + return (*postgres.SessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease save postgresql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + (*postgres.SessionStore)(st).SessionRelease(context.Background(), w) +} + +// Provider postgresql session provider +type Provider postgres.Provider + +// SessionInit init postgresql session. +// savepath is the connection string of postgresql. +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*postgres.Provider)(mp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead get postgresql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*postgres.Provider)(mp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check postgresql session exist +func (mp *Provider) SessionExist(sid string) bool { + res, _ := (*postgres.Provider)(mp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for postgresql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*postgres.Provider)(mp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete postgresql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + return (*postgres.Provider)(mp).SessionDestroy(context.Background(), sid) +} + +// SessionGC delete expired values in postgresql session +func (mp *Provider) SessionGC() { + (*postgres.Provider)(mp).SessionGC(context.Background()) +} + +// SessionAll count values in postgresql session +func (mp *Provider) SessionAll() int { + return (*postgres.Provider)(mp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/provider_adapter.go b/pkg/adapter/session/provider_adapter.go new file mode 100644 index 00000000..11177a4d --- /dev/null +++ b/pkg/adapter/session/provider_adapter.go @@ -0,0 +1,104 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +type oldToNewProviderAdapter struct { + delegate Provider +} + +func (o *oldToNewProviderAdapter) SessionInit(ctx context.Context, gclifetime int64, config string) error { + return o.delegate.SessionInit(gclifetime, config) +} + +func (o *oldToNewProviderAdapter) SessionRead(ctx context.Context, sid string) (session.Store, error) { + store, err := o.delegate.SessionRead(sid) + return &oldToNewStoreAdapter{ + delegate: store, + }, err +} + +func (o *oldToNewProviderAdapter) SessionExist(ctx context.Context, sid string) (bool, error) { + return o.delegate.SessionExist(sid), nil +} + +func (o *oldToNewProviderAdapter) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { + s, err := o.delegate.SessionRegenerate(oldsid, sid) + return &oldToNewStoreAdapter{ + delegate: s, + }, err +} + +func (o *oldToNewProviderAdapter) SessionDestroy(ctx context.Context, sid string) error { + return o.delegate.SessionDestroy(sid) +} + +func (o *oldToNewProviderAdapter) SessionAll(ctx context.Context) int { + return o.delegate.SessionAll() +} + +func (o *oldToNewProviderAdapter) SessionGC(ctx context.Context) { + o.delegate.SessionGC() +} + +type newToOldProviderAdapter struct { + delegate session.Provider +} + +func (n *newToOldProviderAdapter) SessionInit(gclifetime int64, config string) error { + return n.delegate.SessionInit(context.Background(), gclifetime, config) +} + +func (n *newToOldProviderAdapter) SessionRead(sid string) (Store, error) { + s, err := n.delegate.SessionRead(context.Background(), sid) + if adt, ok := s.(*oldToNewStoreAdapter); err == nil && ok { + return adt.delegate, err + } + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +func (n *newToOldProviderAdapter) SessionExist(sid string) bool { + res, _ := n.delegate.SessionExist(context.Background(), sid) + return res +} + +func (n *newToOldProviderAdapter) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := n.delegate.SessionRegenerate(context.Background(), oldsid, sid) + if adt, ok := s.(*oldToNewStoreAdapter); err == nil && ok { + return adt.delegate, err + } + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +func (n *newToOldProviderAdapter) SessionDestroy(sid string) error { + return n.delegate.SessionDestroy(context.Background(), sid) +} + +func (n *newToOldProviderAdapter) SessionAll() int { + return n.delegate.SessionAll(context.Background()) +} + +func (n *newToOldProviderAdapter) SessionGC() { + n.delegate.SessionGC(context.Background()) +} diff --git a/pkg/adapter/session/redis/sess_redis.go b/pkg/adapter/session/redis/sess_redis.go new file mode 100644 index 00000000..6c521e50 --- /dev/null +++ b/pkg/adapter/session/redis/sess_redis.go @@ -0,0 +1,121 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + beeRedis "github.com/astaxie/beego/pkg/infrastructure/session/redis" +) + +// MaxPoolSize redis max pool size +var MaxPoolSize = beeRedis.MaxPoolSize + +// SessionStore redis session store +type SessionStore beeRedis.SessionStore + +// Set value in redis session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*beeRedis.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in redis session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*beeRedis.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in redis session +func (rs *SessionStore) Delete(key interface{}) error { + return (*beeRedis.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in redis session +func (rs *SessionStore) Flush() error { + return (*beeRedis.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get redis session id +func (rs *SessionStore) SessionID() string { + return (*beeRedis.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to redis +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beeRedis.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// Provider redis session provider +type Provider beeRedis.Provider + +// SessionInit init redis session +// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second +// e.g. 127.0.0.1:6379,100,astaxie,0,30 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*beeRedis.Provider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read redis session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beeRedis.Provider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check redis session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + res, _ := (*beeRedis.Provider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for redis session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beeRedis.Provider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + return (*beeRedis.Provider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { + (*beeRedis.Provider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return (*beeRedis.Provider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/redis_cluster/redis_cluster.go b/pkg/adapter/session/redis_cluster/redis_cluster.go new file mode 100644 index 00000000..03a805e4 --- /dev/null +++ b/pkg/adapter/session/redis_cluster/redis_cluster.go @@ -0,0 +1,120 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_cluster" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis_cluster + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + cluster "github.com/astaxie/beego/pkg/infrastructure/session/redis_cluster" +) + +// MaxPoolSize redis_cluster max pool size +var MaxPoolSize = cluster.MaxPoolSize + +// SessionStore redis_cluster session store +type SessionStore cluster.SessionStore + +// Set value in redis_cluster session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*cluster.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in redis_cluster session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*cluster.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in redis_cluster session +func (rs *SessionStore) Delete(key interface{}) error { + return (*cluster.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in redis_cluster session +func (rs *SessionStore) Flush() error { + return (*cluster.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get redis_cluster session id +func (rs *SessionStore) SessionID() string { + return (*cluster.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to redis_cluster +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*cluster.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// Provider redis_cluster session provider +type Provider cluster.Provider + +// SessionInit init redis_cluster session +// savepath like redis server addr,pool size,password,dbnum +// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*cluster.Provider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read redis_cluster session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*cluster.Provider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check redis_cluster session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + res, _ := (*cluster.Provider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for redis_cluster session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*cluster.Provider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + return (*cluster.Provider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { + (*cluster.Provider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return (*cluster.Provider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go new file mode 100644 index 00000000..f5eb8a4f --- /dev/null +++ b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go @@ -0,0 +1,121 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_sentinel" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``) +// go globalSessions.GC() +// } +// +// more detail about params: please check the notes on the function SessionInit in this package +package redis_sentinel + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + sentinel "github.com/astaxie/beego/pkg/infrastructure/session/redis_sentinel" +) + +// DefaultPoolSize redis_sentinel default pool size +var DefaultPoolSize = sentinel.DefaultPoolSize + +// SessionStore redis_sentinel session store +type SessionStore sentinel.SessionStore + +// Set value in redis_sentinel session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*sentinel.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in redis_sentinel session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*sentinel.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in redis_sentinel session +func (rs *SessionStore) Delete(key interface{}) error { + return (*sentinel.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in redis_sentinel session +func (rs *SessionStore) Flush() error { + return (*sentinel.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get redis_sentinel session id +func (rs *SessionStore) SessionID() string { + return (*sentinel.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to redis_sentinel +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*sentinel.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// Provider redis_sentinel session provider +type Provider sentinel.Provider + +// SessionInit init redis_sentinel session +// savepath like redis sentinel addr,pool size,password,dbnum,masterName +// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*sentinel.Provider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read redis_sentinel session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*sentinel.Provider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check redis_sentinel session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + res, _ := (*sentinel.Provider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for redis_sentinel session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*sentinel.Provider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + return (*sentinel.Provider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { + (*sentinel.Provider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return (*sentinel.Provider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go new file mode 100644 index 00000000..7c33985f --- /dev/null +++ b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go @@ -0,0 +1,90 @@ +package redis_sentinel + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/astaxie/beego/pkg/adapter/session" +) + +func TestRedisSentinel(t *testing.T) { + sessionConfig := &session.ManagerConfig{ + CookieName: "gosessionid", + EnableSetCookie: true, + Gclifetime: 3600, + Maxlifetime: 3600, + Secure: false, + CookieLifeTime: 3600, + ProviderConfig: "127.0.0.1:6379,100,,0,master", + } + globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) + if e != nil { + t.Log(e) + return + } + // todo test if e==nil + go globalSessions.GC() + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start failed:", err) + } + defer sess.SessionRelease(w) + + // SET AND GET + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set username failed:", err) + } + username := sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + + // DELETE + err = sess.Delete("username") + if err != nil { + t.Fatal("delete username failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("delete username failed") + } + + // FLUSH + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set failed:", err) + } + err = sess.Set("password", "1qaz2wsx") + if err != nil { + t.Fatal("set failed:", err) + } + username = sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + password := sess.Get("password") + if password != "1qaz2wsx" { + t.Fatal("get password failed") + } + err = sess.Flush() + if err != nil { + t.Fatal("flush failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("flush failed") + } + password = sess.Get("password") + if password != nil { + t.Fatal("flush failed") + } + + sess.SessionRelease(w) + +} diff --git a/pkg/adapter/session/sess_cookie.go b/pkg/adapter/session/sess_cookie.go new file mode 100644 index 00000000..32216040 --- /dev/null +++ b/pkg/adapter/session/sess_cookie.go @@ -0,0 +1,114 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// CookieSessionStore Cookie SessionStore +type CookieSessionStore session.CookieSessionStore + +// Set value to cookie session. +// the value are encoded as gob with hash block string. +func (st *CookieSessionStore) Set(key, value interface{}) error { + return (*session.CookieSessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from cookie session +func (st *CookieSessionStore) Get(key interface{}) interface{} { + return (*session.CookieSessionStore)(st).Get(context.Background(), key) +} + +// Delete value in cookie session +func (st *CookieSessionStore) Delete(key interface{}) error { + return (*session.CookieSessionStore)(st).Delete(context.Background(), key) +} + +// Flush Clean all values in cookie session +func (st *CookieSessionStore) Flush() error { + return (*session.CookieSessionStore)(st).Flush(context.Background()) +} + +// SessionID Return id of this cookie session +func (st *CookieSessionStore) SessionID() string { + return (*session.CookieSessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease Write cookie session to http response cookie +func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { + (*session.CookieSessionStore)(st).SessionRelease(context.Background(), w) +} + +// CookieProvider Cookie session provider +type CookieProvider session.CookieProvider + +// SessionInit Init cookie session provider with max lifetime and config json. +// maxlifetime is ignored. +// json config: +// securityKey - hash string +// blockKey - gob encode hash string. it's saved as aes crypto. +// securityName - recognized name in encoded cookie string +// cookieName - cookie name +// maxage - cookie max life time. +func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { + return (*session.CookieProvider)(pder).SessionInit(context.Background(), maxlifetime, config) +} + +// SessionRead Get SessionStore in cooke. +// decode cooke string to map and put into SessionStore with sid. +func (pder *CookieProvider) SessionRead(sid string) (Store, error) { + s, err := (*session.CookieProvider)(pder).SessionRead(context.Background(), sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionExist Cookie session is always existed +func (pder *CookieProvider) SessionExist(sid string) bool { + res, _ := (*session.CookieProvider)(pder).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate Implement method, no used. +func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := (*session.CookieProvider)(pder).SessionRegenerate(context.Background(), oldsid, sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionDestroy Implement method, no used. +func (pder *CookieProvider) SessionDestroy(sid string) error { + return (*session.CookieProvider)(pder).SessionDestroy(context.Background(), sid) +} + +// SessionGC Implement method, no used. +func (pder *CookieProvider) SessionGC() { + (*session.CookieProvider)(pder).SessionGC(context.Background()) +} + +// SessionAll Implement method, return 0. +func (pder *CookieProvider) SessionAll() int { + return (*session.CookieProvider)(pder).SessionAll(context.Background()) +} + +// SessionUpdate Implement method, no used. +func (pder *CookieProvider) SessionUpdate(sid string) error { + return (*session.CookieProvider)(pder).SessionUpdate(context.Background(), sid) +} diff --git a/pkg/adapter/session/sess_cookie_test.go b/pkg/adapter/session/sess_cookie_test.go new file mode 100644 index 00000000..b6726005 --- /dev/null +++ b/pkg/adapter/session/sess_cookie_test.go @@ -0,0 +1,105 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + sess.SessionRelease(w) + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} + +func TestDestorySessionCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + session, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start err,", err) + } + + // request again ,will get same sesssion id . + r1, _ := http.NewRequest("GET", "/", nil) + r1.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + w = httptest.NewRecorder() + newSession, err := globalSessions.SessionStart(w, r1) + if err != nil { + t.Fatal("session start err,", err) + } + if newSession.SessionID() != session.SessionID() { + t.Fatal("get cookie session id is not the same again.") + } + + // After destroy session , will get a new session id . + globalSessions.SessionDestroy(w, r1) + r2, _ := http.NewRequest("GET", "/", nil) + r2.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + + w = httptest.NewRecorder() + newSession, err = globalSessions.SessionStart(w, r2) + if err != nil { + t.Fatal("session start error") + } + if newSession.SessionID() == session.SessionID() { + t.Fatal("after destroy session and reqeust again ,get cookie session id is same.") + } +} diff --git a/pkg/adapter/session/sess_file.go b/pkg/adapter/session/sess_file.go new file mode 100644 index 00000000..b9648998 --- /dev/null +++ b/pkg/adapter/session/sess_file.go @@ -0,0 +1,106 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// FileSessionStore File session store +type FileSessionStore session.FileSessionStore + +// Set value to file session +func (fs *FileSessionStore) Set(key, value interface{}) error { + return (*session.FileSessionStore)(fs).Set(context.Background(), key, value) +} + +// Get value from file session +func (fs *FileSessionStore) Get(key interface{}) interface{} { + return (*session.FileSessionStore)(fs).Get(context.Background(), key) +} + +// Delete value in file session by given key +func (fs *FileSessionStore) Delete(key interface{}) error { + return (*session.FileSessionStore)(fs).Delete(context.Background(), key) +} + +// Flush Clean all values in file session +func (fs *FileSessionStore) Flush() error { + return (*session.FileSessionStore)(fs).Flush(context.Background()) +} + +// SessionID Get file session store id +func (fs *FileSessionStore) SessionID() string { + return (*session.FileSessionStore)(fs).SessionID(context.Background()) +} + +// SessionRelease Write file session to local file with Gob string +func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { + (*session.FileSessionStore)(fs).SessionRelease(context.Background(), w) +} + +// FileProvider File session provider +type FileProvider session.FileProvider + +// SessionInit Init file session provider. +// savePath sets the session files path. +func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { + return (*session.FileProvider)(fp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead Read file session by sid. +// if file is not exist, create it. +// the file path is generated from sid string. +func (fp *FileProvider) SessionRead(sid string) (Store, error) { + s, err := (*session.FileProvider)(fp).SessionRead(context.Background(), sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionExist Check file session exist. +// it checks the file named from sid exist or not. +func (fp *FileProvider) SessionExist(sid string) bool { + res, _ := (*session.FileProvider)(fp).SessionExist(context.Background(), sid) + return res +} + +// SessionDestroy Remove all files in this save path +func (fp *FileProvider) SessionDestroy(sid string) error { + return (*session.FileProvider)(fp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Recycle files in save path +func (fp *FileProvider) SessionGC() { + (*session.FileProvider)(fp).SessionGC(context.Background()) +} + +// SessionAll Get active file session number. +// it walks save path to count files. +func (fp *FileProvider) SessionAll() int { + return (*session.FileProvider)(fp).SessionAll(context.Background()) +} + +// SessionRegenerate Generate new sid for file session. +// it delete old file and create new file named from new sid. +func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := (*session.FileProvider)(fp).SessionRegenerate(context.Background(), oldsid, sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} diff --git a/pkg/adapter/session/sess_file_test.go b/pkg/adapter/session/sess_file_test.go new file mode 100644 index 00000000..4c90a3ac --- /dev/null +++ b/pkg/adapter/session/sess_file_test.go @@ -0,0 +1,336 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "fmt" + "os" + "sync" + "testing" + "time" +) + +const sid = "Session_id" +const sidNew = "Session_id_new" +const sessionPath = "./_session_runtime" + +var ( + mutex sync.Mutex +) + +func TestFileProvider_SessionExist(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionExist2(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + if fp.SessionExist("") { + t.Error() + } + + if fp.SessionExist("1") { + t.Error() + } +} + +func TestFileProvider_SessionRead(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + _ = s.Set("sessionValue", 18975) + v := s.Get("sessionValue") + + if v.(int) != 18975 { + t.Error() + } +} + +func TestFileProvider_SessionRead1(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead("") + if err == nil { + t.Error(err) + } + + _, err = fp.SessionRead("1") + if err == nil { + t.Error(err) + } +} + +func TestFileProvider_SessionAll(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 546 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + if fp.SessionAll() != sessionCount { + t.Error() + } +} + +func TestFileProvider_SessionRegenerate(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + _, err = fp.SessionRegenerate(sid, sidNew) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } + + if !fp.SessionExist(sidNew) { + t.Error() + } +} + +func TestFileProvider_SessionDestroy(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + err = fp.SessionDestroy(sid) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionGC(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(1, sessionPath) + + sessionCount := 412 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + time.Sleep(2 * time.Second) + + fp.SessionGC() + if fp.SessionAll() != 0 { + t.Error() + } +} + +func TestFileSessionStore_Set(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + err := s.Set(i, i) + if err != nil { + t.Error(err) + } + } +} + +func TestFileSessionStore_Get(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + + v := s.Get(i) + if v.(int) != i { + t.Error() + } + } +} + +func TestFileSessionStore_Delete(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, _ := fp.SessionRead(sid) + s.Set("1", 1) + + if s.Get("1") == nil { + t.Error() + } + + s.Delete("1") + + if s.Get("1") != nil { + t.Error() + } +} + +func TestFileSessionStore_Flush(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + } + + _ = s.Flush() + + for i := 1; i <= sessionCount; i++ { + if s.Get(i) != nil { + t.Error() + } + } +} + +func TestFileSessionStore_SessionID(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) { + t.Error(err) + } + } +} diff --git a/pkg/adapter/session/sess_mem.go b/pkg/adapter/session/sess_mem.go new file mode 100644 index 00000000..818c8329 --- /dev/null +++ b/pkg/adapter/session/sess_mem.go @@ -0,0 +1,106 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// MemSessionStore memory session store. +// it saved sessions in a map in memory. +type MemSessionStore session.MemSessionStore + +// Set value to memory session +func (st *MemSessionStore) Set(key, value interface{}) error { + return (*session.MemSessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from memory session by key +func (st *MemSessionStore) Get(key interface{}) interface{} { + return (*session.MemSessionStore)(st).Get(context.Background(), key) +} + +// Delete in memory session by key +func (st *MemSessionStore) Delete(key interface{}) error { + return (*session.MemSessionStore)(st).Delete(context.Background(), key) +} + +// Flush clear all values in memory session +func (st *MemSessionStore) Flush() error { + return (*session.MemSessionStore)(st).Flush(context.Background()) +} + +// SessionID get this id of memory session store +func (st *MemSessionStore) SessionID() string { + return (*session.MemSessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease Implement method, no used. +func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { + (*session.MemSessionStore)(st).SessionRelease(context.Background(), w) +} + +// MemProvider Implement the provider interface +type MemProvider session.MemProvider + +// SessionInit init memory session +func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + return (*session.MemProvider)(pder).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead get memory session store by sid +func (pder *MemProvider) SessionRead(sid string) (Store, error) { + s, err := (*session.MemProvider)(pder).SessionRead(context.Background(), sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionExist check session store exist in memory session by sid +func (pder *MemProvider) SessionExist(sid string) bool { + res, _ := (*session.MemProvider)(pder).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for session store in memory session +func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := (*session.MemProvider)(pder).SessionRegenerate(context.Background(), oldsid, sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionDestroy delete session store in memory session by id +func (pder *MemProvider) SessionDestroy(sid string) error { + return (*session.MemProvider)(pder).SessionDestroy(context.Background(), sid) +} + +// SessionGC clean expired session stores in memory session +func (pder *MemProvider) SessionGC() { + (*session.MemProvider)(pder).SessionGC(context.Background()) +} + +// SessionAll get count number of memory session +func (pder *MemProvider) SessionAll() int { + return (*session.MemProvider)(pder).SessionAll(context.Background()) +} + +// SessionUpdate expand time of session store by id in memory session +func (pder *MemProvider) SessionUpdate(sid string) error { + return (*session.MemProvider)(pder).SessionUpdate(context.Background(), sid) +} diff --git a/pkg/adapter/session/sess_mem_test.go b/pkg/adapter/session/sess_mem_test.go new file mode 100644 index 00000000..2e8934b8 --- /dev/null +++ b/pkg/adapter/session/sess_mem_test.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestMem(t *testing.T) { + config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, _ := NewManager("memory", conf) + go globalSessions.GC() + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + defer sess.SessionRelease(w) + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} diff --git a/pkg/adapter/session/sess_test.go b/pkg/adapter/session/sess_test.go new file mode 100644 index 00000000..aba702ca --- /dev/null +++ b/pkg/adapter/session/sess_test.go @@ -0,0 +1,51 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "testing" +) + +func Test_gob(t *testing.T) { + a := make(map[interface{}]interface{}) + a["username"] = "astaxie" + a[12] = 234 + a["user"] = User{"asta", "xie"} + b, err := EncodeGob(a) + if err != nil { + t.Error(err) + } + c, err := DecodeGob(b) + if err != nil { + t.Error(err) + } + if len(c) == 0 { + t.Error("decodeGob empty") + } + if c["username"] != "astaxie" { + t.Error("decode string error") + } + if c[12] != 234 { + t.Error("decode int error") + } + if c["user"].(User).Username != "asta" { + t.Error("decode struct error") + } +} + +type User struct { + Username string + NickName string +} diff --git a/pkg/adapter/session/sess_utils.go b/pkg/adapter/session/sess_utils.go new file mode 100644 index 00000000..3d107198 --- /dev/null +++ b/pkg/adapter/session/sess_utils.go @@ -0,0 +1,29 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// EncodeGob encode the obj to gob +func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) { + return session.EncodeGob(obj) +} + +// DecodeGob decode data to map +func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) { + return session.DecodeGob(encoded) +} diff --git a/pkg/adapter/session/session.go b/pkg/adapter/session/session.go new file mode 100644 index 00000000..eea2f90e --- /dev/null +++ b/pkg/adapter/session/session.go @@ -0,0 +1,166 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package session provider +// +// Usage: +// import( +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package session + +import ( + "io" + "net/http" + "os" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// Store contains all data for one session process with specific id. +type Store interface { + Set(key, value interface{}) error // set session value + Get(key interface{}) interface{} // get session value + Delete(key interface{}) error // delete session value + SessionID() string // back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error // delete all data +} + +// Provider contains global session methods and saved SessionStores. +// it can operate a SessionStore by its id. +type Provider interface { + SessionInit(gclifetime int64, config string) error + SessionRead(sid string) (Store, error) + SessionExist(sid string) bool + SessionRegenerate(oldsid, sid string) (Store, error) + SessionDestroy(sid string) error + SessionAll() int // get all active session + SessionGC() +} + +// SLogger a helpful variable to log information about session +var SLogger = NewSessionLog(os.Stderr) + +// Register makes a session provide available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, provide Provider) { + session.Register(name, &oldToNewProviderAdapter{ + delegate: provide, + }) +} + +// GetProvider +func GetProvider(name string) (Provider, error) { + res, err := session.GetProvider(name) + if adt, ok := res.(*oldToNewProviderAdapter); err == nil && ok { + return adt.delegate, err + } + + return &newToOldProviderAdapter{ + delegate: res, + }, err +} + +// ManagerConfig define the session config +type ManagerConfig session.ManagerConfig + +// Manager contains Provider and its configuration. +type Manager session.Manager + +// NewManager Create new Manager with provider name and json config string. +// provider name: +// 1. cookie +// 2. file +// 3. memory +// 4. redis +// 5. mysql +// json config: +// 1. is https default false +// 2. hashfunc default sha1 +// 3. hashkey default beegosessionkey +// 4. maxage default is none +func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { + m, err := session.NewManager(provideName, (*session.ManagerConfig)(cf)) + return (*Manager)(m), err +} + +// GetProvider return current manager's provider +func (manager *Manager) GetProvider() Provider { + return &newToOldProviderAdapter{ + delegate: (*session.Manager)(manager).GetProvider(), + } +} + +// SessionStart generate or read the session id from http request. +// if session id exists, return SessionStore with this id. +func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (Store, error) { + s, err := (*session.Manager)(manager).SessionStart(w, r) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionDestroy Destroy session by its id in http request cookie. +func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { + (*session.Manager)(manager).SessionDestroy(w, r) +} + +// GetSessionStore Get SessionStore by its id. +func (manager *Manager) GetSessionStore(sid string) (Store, error) { + s, err := (*session.Manager)(manager).GetSessionStore(sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// GC Start session gc process. +// it can do gc in times after gc lifetime. +func (manager *Manager) GC() { + (*session.Manager)(manager).GC() +} + +// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. +func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) Store { + s := (*session.Manager)(manager).SessionRegenerateID(w, r) + return &NewToOldStoreAdapter{ + delegate: s, + } +} + +// GetActiveSession Get all active sessions count number. +func (manager *Manager) GetActiveSession() int { + return (*session.Manager)(manager).GetActiveSession() +} + +// SetSecure Set cookie with https. +func (manager *Manager) SetSecure(secure bool) { + (*session.Manager)(manager).SetSecure(secure) +} + +// Log implement the log.Logger +type Log session.Log + +// NewSessionLog set io.Writer to create a Logger for session. +func NewSessionLog(out io.Writer) *Log { + return (*Log)(session.NewSessionLog(out)) +} diff --git a/pkg/adapter/session/ssdb/sess_ssdb.go b/pkg/adapter/session/ssdb/sess_ssdb.go new file mode 100644 index 00000000..aee3a364 --- /dev/null +++ b/pkg/adapter/session/ssdb/sess_ssdb.go @@ -0,0 +1,84 @@ +package ssdb + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + beeSsdb "github.com/astaxie/beego/pkg/infrastructure/session/ssdb" +) + +// Provider holds ssdb client and configs +type Provider beeSsdb.Provider + +// SessionInit init the ssdb with the config +func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { + return (*beeSsdb.Provider)(p).SessionInit(context.Background(), maxLifetime, savePath) +} + +// SessionRead return a ssdb client session Store +func (p *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beeSsdb.Provider)(p).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist judged whether sid is exist in session +func (p *Provider) SessionExist(sid string) bool { + res, _ := (*beeSsdb.Provider)(p).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate regenerate session with new sid and delete oldsid +func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beeSsdb.Provider)(p).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy destroy the sid +func (p *Provider) SessionDestroy(sid string) error { + return (*beeSsdb.Provider)(p).SessionDestroy(context.Background(), sid) +} + +// SessionGC not implemented +func (p *Provider) SessionGC() { + (*beeSsdb.Provider)(p).SessionGC(context.Background()) +} + +// SessionAll not implemented +func (p *Provider) SessionAll() int { + return (*beeSsdb.Provider)(p).SessionAll(context.Background()) +} + +// SessionStore holds the session information which stored in ssdb +type SessionStore beeSsdb.SessionStore + +// Set the key and value +func (s *SessionStore) Set(key, value interface{}) error { + return (*beeSsdb.SessionStore)(s).Set(context.Background(), key, value) +} + +// Get return the value by the key +func (s *SessionStore) Get(key interface{}) interface{} { + return (*beeSsdb.SessionStore)(s).Get(context.Background(), key) +} + +// Delete the key in session store +func (s *SessionStore) Delete(key interface{}) error { + return (*beeSsdb.SessionStore)(s).Delete(context.Background(), key) +} + +// Flush delete all keys and values +func (s *SessionStore) Flush() error { + return (*beeSsdb.SessionStore)(s).Flush(context.Background()) +} + +// SessionID return the sessionID +func (s *SessionStore) SessionID() string { + return (*beeSsdb.SessionStore)(s).SessionID(context.Background()) +} + +// SessionRelease Store the keyvalues into ssdb +func (s *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beeSsdb.SessionStore)(s).SessionRelease(context.Background(), w) +} diff --git a/pkg/adapter/session/store_adapter.go b/pkg/adapter/session/store_adapter.go new file mode 100644 index 00000000..c1a03c38 --- /dev/null +++ b/pkg/adapter/session/store_adapter.go @@ -0,0 +1,84 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +type NewToOldStoreAdapter struct { + delegate session.Store +} + +func CreateNewToOldStoreAdapter(s session.Store) Store { + return &NewToOldStoreAdapter{ + delegate: s, + } +} + +func (n *NewToOldStoreAdapter) Set(key, value interface{}) error { + return n.delegate.Set(context.Background(), key, value) +} + +func (n *NewToOldStoreAdapter) Get(key interface{}) interface{} { + return n.delegate.Get(context.Background(), key) +} + +func (n *NewToOldStoreAdapter) Delete(key interface{}) error { + return n.delegate.Delete(context.Background(), key) +} + +func (n *NewToOldStoreAdapter) SessionID() string { + return n.delegate.SessionID(context.Background()) +} + +func (n *NewToOldStoreAdapter) SessionRelease(w http.ResponseWriter) { + n.delegate.SessionRelease(context.Background(), w) +} + +func (n *NewToOldStoreAdapter) Flush() error { + return n.delegate.Flush(context.Background()) +} + +type oldToNewStoreAdapter struct { + delegate Store +} + +func (o *oldToNewStoreAdapter) Set(ctx context.Context, key, value interface{}) error { + return o.delegate.Set(key, value) +} + +func (o *oldToNewStoreAdapter) Get(ctx context.Context, key interface{}) interface{} { + return o.delegate.Get(key) +} + +func (o *oldToNewStoreAdapter) Delete(ctx context.Context, key interface{}) error { + return o.delegate.Delete(key) +} + +func (o *oldToNewStoreAdapter) SessionID(ctx context.Context) string { + return o.delegate.SessionID() +} + +func (o *oldToNewStoreAdapter) SessionRelease(ctx context.Context, w http.ResponseWriter) { + o.delegate.SessionRelease(w) +} + +func (o *oldToNewStoreAdapter) Flush(ctx context.Context) error { + return o.delegate.Flush() +} diff --git a/pkg/infrastructure/session/sess_cookie.go b/pkg/infrastructure/session/sess_cookie.go index ffb19fb7..649f6510 100644 --- a/pkg/infrastructure/session/sess_cookie.go +++ b/pkg/infrastructure/session/sess_cookie.go @@ -172,7 +172,7 @@ func (pder *CookieProvider) SessionAll(context.Context) int { } // SessionUpdate Implement method, no used. -func (pder *CookieProvider) SessionUpdate(sid string) error { +func (pder *CookieProvider) SessionUpdate(ctx context.Context, sid string) error { return nil } diff --git a/pkg/infrastructure/session/sess_mem.go b/pkg/infrastructure/session/sess_mem.go index 9a27c331..27e24c73 100644 --- a/pkg/infrastructure/session/sess_mem.go +++ b/pkg/infrastructure/session/sess_mem.go @@ -96,7 +96,7 @@ func (pder *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, sav func (pder *MemProvider) SessionRead(ctx context.Context, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[sid]; ok { - go pder.SessionUpdate(sid) + go pder.SessionUpdate(nil, sid) pder.lock.RUnlock() return element.Value.(*MemSessionStore), nil } @@ -123,7 +123,7 @@ func (pder *MemProvider) SessionExist(ctx context.Context, sid string) (bool, er func (pder *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[oldsid]; ok { - go pder.SessionUpdate(oldsid) + go pder.SessionUpdate(nil, oldsid) pder.lock.RUnlock() pder.lock.Lock() element.Value.(*MemSessionStore).sid = sid @@ -181,7 +181,7 @@ func (pder *MemProvider) SessionAll(context.Context) int { } // SessionUpdate expand time of session store by id in memory session -func (pder *MemProvider) SessionUpdate(sid string) error { +func (pder *MemProvider) SessionUpdate(ctx context.Context, sid string) error { pder.lock.Lock() defer pder.lock.Unlock() if element, ok := pder.sessions[sid]; ok { From 1dae2c9eb3fbe06c21639409190ec526d66e6e5e Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 2 Sep 2020 22:44:31 +0800 Subject: [PATCH 23/35] Adapter: web module --- pkg/adapter/admin.go | 48 ++++ pkg/adapter/app.go | 261 ++++++++++++++++++++ pkg/adapter/beego.go | 75 ++++++ pkg/adapter/build_info.go | 27 +++ pkg/adapter/config.go | 179 ++++++++++++++ pkg/adapter/controller.go | 401 +++++++++++++++++++++++++++++++ pkg/adapter/error.go | 202 ++++++++++++++++ pkg/adapter/filter.go | 36 +++ pkg/adapter/flash.go | 63 +++++ pkg/adapter/fs.go | 35 +++ pkg/adapter/log.go | 129 ++++++++++ pkg/adapter/namespace.go | 378 +++++++++++++++++++++++++++++ pkg/adapter/policy.go | 57 +++++ pkg/adapter/router.go | 279 +++++++++++++++++++++ pkg/adapter/template.go | 108 +++++++++ pkg/adapter/templatefunc.go | 151 ++++++++++++ pkg/adapter/templatefunc_test.go | 304 +++++++++++++++++++++++ pkg/adapter/tree.go | 49 ++++ pkg/adapter/tree_test.go | 249 +++++++++++++++++++ pkg/server/web/app.go | 2 +- pkg/server/web/config.go | 3 + pkg/server/web/filter.go | 46 +++- pkg/server/web/router.go | 14 +- 23 files changed, 3080 insertions(+), 16 deletions(-) create mode 100644 pkg/adapter/admin.go create mode 100644 pkg/adapter/app.go create mode 100644 pkg/adapter/beego.go create mode 100644 pkg/adapter/build_info.go create mode 100644 pkg/adapter/config.go create mode 100644 pkg/adapter/controller.go create mode 100644 pkg/adapter/error.go create mode 100644 pkg/adapter/filter.go create mode 100644 pkg/adapter/flash.go create mode 100644 pkg/adapter/fs.go create mode 100644 pkg/adapter/log.go create mode 100644 pkg/adapter/namespace.go create mode 100644 pkg/adapter/policy.go create mode 100644 pkg/adapter/router.go create mode 100644 pkg/adapter/template.go create mode 100644 pkg/adapter/templatefunc.go create mode 100644 pkg/adapter/templatefunc_test.go create mode 100644 pkg/adapter/tree.go create mode 100644 pkg/adapter/tree_test.go diff --git a/pkg/adapter/admin.go b/pkg/adapter/admin.go new file mode 100644 index 00000000..87e7259b --- /dev/null +++ b/pkg/adapter/admin.go @@ -0,0 +1,48 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "time" + + "github.com/astaxie/beego/pkg/server/web" +) + +// FilterMonitorFunc is default monitor filter when admin module is enable. +// if this func returns, admin module records qps for this request by condition of this function logic. +// usage: +// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { +// if method == "POST" { +// return false +// } +// if t.Nanoseconds() < 100 { +// return false +// } +// if strings.HasPrefix(requestPath, "/astaxie") { +// return false +// } +// return true +// } +// beego.FilterMonitorFunc = MyFilterMonitor. +var FilterMonitorFunc func(string, string, time.Duration, string, int) bool + +func init() { + FilterMonitorFunc = web.FilterMonitorFunc +} + +// PrintTree prints all registered routers. +func PrintTree() M { + return (M)(web.PrintTree()) +} diff --git a/pkg/adapter/app.go b/pkg/adapter/app.go new file mode 100644 index 00000000..64280a7b --- /dev/null +++ b/pkg/adapter/app.go @@ -0,0 +1,261 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + + context2 "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/pkg/server/web/context" +) + +var ( + // BeeApp is an application instance + BeeApp *App +) + +func init() { + // create beego application + BeeApp = (*App)(web.BeeApp) +} + +// App defines beego application with a new PatternServeMux. +type App web.App + +// NewApp returns a new beego application. +func NewApp() *App { + return (*App)(web.NewApp()) +} + +// MiddleWare function for http.Handler +type MiddleWare web.MiddleWare + +// Run beego application. +func (app *App) Run(mws ...MiddleWare) { + newMws := oldMiddlewareToNew(mws) + (*web.App)(app).Run(newMws...) +} + +func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare { + newMws := make([]web.MiddleWare, 0, len(mws)) + for _, old := range mws { + newMws = append(newMws, (web.MiddleWare)(old)) + } + return newMws +} + +// Router adds a patterned controller handler to BeeApp. +// it's an alias method of App.Router. +// usage: +// simple router +// beego.Router("/admin", &admin.UserController{}) +// beego.Router("/admin/index", &admin.ArticleController{}) +// +// regex router +// +// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) +// +// custom rules +// beego.Router("/api/list",&RestController{},"*:ListFood") +// beego.Router("/api/create",&RestController{},"post:CreateFood") +// beego.Router("/api/update",&RestController{},"put:UpdateFood") +// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") +func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { + return (*App)(web.Router(rootpath, c, mappingMethods...)) +} + +// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful +// in web applications that inherit most routes from a base webapp via the underscore +// import, and aim to overwrite only certain paths. +// The method parameter can be empty or "*" for all HTTP methods, or a particular +// method type (e.g. "GET" or "POST") for selective removal. +// +// Usage (replace "GET" with "*" for all methods): +// beego.UnregisterFixedRoute("/yourpreviouspath", "GET") +// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") +func UnregisterFixedRoute(fixedRoute string, method string) *App { + return (*App)(web.UnregisterFixedRoute(fixedRoute, method)) +} + +// Include will generate router file in the router/xxx.go from the controller's comments +// usage: +// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +// type BankAccount struct{ +// beego.Controller +// } +// +// register the function +// func (b *BankAccount)Mapping(){ +// b.Mapping("ShowAccount" , b.ShowAccount) +// b.Mapping("ModifyAccount", b.ModifyAccount) +// } +// +// //@router /account/:id [get] +// func (b *BankAccount) ShowAccount(){ +// //logic +// } +// +// +// //@router /account/:id [post] +// func (b *BankAccount) ModifyAccount(){ +// //logic +// } +// +// the comments @router url methodlist +// url support all the function Router's pattern +// methodlist [get post head put delete options *] +func Include(cList ...ControllerInterface) *App { + newList := oldToNewCtrlIntfs(cList) + return (*App)(web.Include(newList...)) +} + +func oldToNewCtrlIntfs(cList []ControllerInterface) []web.ControllerInterface { + newList := make([]web.ControllerInterface, 0, len(cList)) + for _, c := range cList { + newList = append(newList, c) + } + return newList +} + +// RESTRouter adds a restful controller handler to BeeApp. +// its' controller implements beego.ControllerInterface and +// defines a param "pattern/:objectId" to visit each resource. +func RESTRouter(rootpath string, c ControllerInterface) *App { + return (*App)(web.RESTRouter(rootpath, c)) +} + +// AutoRouter adds defined controller handler to BeeApp. +// it's same to App.AutoRouter. +// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, +// visit the url /main/list to exec List function or /main/page to exec Page function. +func AutoRouter(c ControllerInterface) *App { + return (*App)(web.AutoRouter(c)) +} + +// AutoPrefix adds controller handler to BeeApp with prefix. +// it's same to App.AutoRouterWithPrefix. +// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, +// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. +func AutoPrefix(prefix string, c ControllerInterface) *App { + return (*App)(web.AutoPrefix(prefix, c)) +} + +// Get used to register router for Get method +// usage: +// beego.Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Get(rootpath string, f FilterFunc) *App { + return (*App)(web.Get(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Post used to register router for Post method +// usage: +// beego.Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Post(rootpath string, f FilterFunc) *App { + return (*App)(web.Post(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Delete used to register router for Delete method +// usage: +// beego.Delete("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Delete(rootpath string, f FilterFunc) *App { + return (*App)(web.Delete(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Put used to register router for Put method +// usage: +// beego.Put("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Put(rootpath string, f FilterFunc) *App { + return (*App)(web.Put(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Head used to register router for Head method +// usage: +// beego.Head("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Head(rootpath string, f FilterFunc) *App { + return (*App)(web.Head(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Options used to register router for Options method +// usage: +// beego.Options("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Options(rootpath string, f FilterFunc) *App { + return (*App)(web.Options(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Patch used to register router for Patch method +// usage: +// beego.Patch("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Patch(rootpath string, f FilterFunc) *App { + return (*App)(web.Patch(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Any used to register router for all methods +// usage: +// beego.Any("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Any(rootpath string, f FilterFunc) *App { + return (*App)(web.Any(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Handler used to register a Handler router +// usage: +// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { +// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) +// })) +func Handler(rootpath string, h http.Handler, options ...interface{}) *App { + return (*App)(web.Handler(rootpath, h, options)) +} + +// InsertFilter adds a FilterFunc with pattern condition and action constant. +// The pos means action constant including +// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. +// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) +func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { + return (*App)(web.InsertFilter(pattern, pos, func(ctx *context.Context) { + filter((*context2.Context)(ctx)) + }, params...)) +} diff --git a/pkg/adapter/beego.go b/pkg/adapter/beego.go new file mode 100644 index 00000000..efd2d4ea --- /dev/null +++ b/pkg/adapter/beego.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/server/web" +) + +const ( + // VERSION represent beego web framework version. + VERSION = web.VERSION + + // DEV is for develop + DEV = web.DEV + // PROD is for production + PROD = web.PROD +) + +// M is Map shortcut +type M web.M + +// Hook function to run +type hookfunc func() error + +var ( + hooks = make([]hookfunc, 0) // hook function slice to store the hookfunc +) + +// AddAPPStartHook is used to register the hookfunc +// The hookfuncs will run in beego.Run() +// such as initiating session , starting middleware , building template, starting admin control and so on. +func AddAPPStartHook(hf ...hookfunc) { + for _, f := range hf { + web.AddAPPStartHook(func() error { + return f() + }) + } +} + +// Run beego application. +// beego.Run() default run on HttpPort +// beego.Run("localhost") +// beego.Run(":8089") +// beego.Run("127.0.0.1:8089") +func Run(params ...string) { + web.Run(params...) +} + +// RunWithMiddleWares Run beego application with middlewares. +func RunWithMiddleWares(addr string, mws ...MiddleWare) { + newMws := oldMiddlewareToNew(mws) + web.RunWithMiddleWares(addr, newMws...) +} + +// TestBeegoInit is for test package init +func TestBeegoInit(ap string) { + web.TestBeegoInit(ap) +} + +// InitBeegoBeforeTest is for test package init +func InitBeegoBeforeTest(appConfigPath string) { + web.InitBeegoBeforeTest(appConfigPath) +} diff --git a/pkg/adapter/build_info.go b/pkg/adapter/build_info.go new file mode 100644 index 00000000..1e8dacf0 --- /dev/null +++ b/pkg/adapter/build_info.go @@ -0,0 +1,27 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +var ( + BuildVersion string + BuildGitRevision string + BuildStatus string + BuildTag string + BuildTime string + + GoVersion string + + GitBranch string +) diff --git a/pkg/adapter/config.go b/pkg/adapter/config.go new file mode 100644 index 00000000..1491722c --- /dev/null +++ b/pkg/adapter/config.go @@ -0,0 +1,179 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + context2 "context" + + "github.com/astaxie/beego/pkg/adapter/session" + newCfg "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/server/web" +) + +// Config is the main struct for BConfig +type Config web.Config + +// Listen holds for http and https related config +type Listen web.Listen + +// WebConfig holds web related config +type WebConfig web.WebConfig + +// SessionConfig holds session related config +type SessionConfig web.SessionConfig + +// LogConfig holds Log related config +type LogConfig web.LogConfig + +var ( + // BConfig is the default config for Application + BConfig *Config + // AppConfig is the instance of Config, store the config information from file + AppConfig *beegoAppConfig + // AppPath is the absolute path to the app + AppPath string + // GlobalSessions is the instance for the session manager + GlobalSessions *session.Manager + + // appConfigPath is the path to the config files + appConfigPath string + // appConfigProvider is the provider for the config, default is ini + appConfigProvider = "ini" + // WorkPath is the absolute path to project root directory + WorkPath string +) + +func init() { + BConfig = (*Config)(web.BConfig) + AppPath = web.AppPath + + WorkPath = web.WorkPath + + AppConfig = &beegoAppConfig{innerConfig: (newCfg.Configer)(web.AppConfig)} +} + +// LoadAppConfig allow developer to apply a config file +func LoadAppConfig(adapterName, configPath string) error { + return web.LoadAppConfig(adapterName, configPath) +} + +type beegoAppConfig struct { + innerConfig newCfg.Configer +} + +func (b *beegoAppConfig) Set(key, val string) error { + if err := b.innerConfig.Set(context2.Background(), BConfig.RunMode+"::"+key, val); err != nil { + return b.innerConfig.Set(context2.Background(), key, val) + } + return nil +} + +func (b *beegoAppConfig) String(key string) string { + if v, err := b.innerConfig.String(context2.Background(), BConfig.RunMode+"::"+key); v != "" && err != nil { + return v + } + res, _ := b.innerConfig.String(context2.Background(), key) + return res +} + +func (b *beegoAppConfig) Strings(key string) []string { + if v, err := b.innerConfig.Strings(context2.Background(), BConfig.RunMode+"::"+key); len(v) > 0 && err != nil { + return v + } + res, _ := b.innerConfig.Strings(context2.Background(), key) + return res +} + +func (b *beegoAppConfig) Int(key string) (int, error) { + if v, err := b.innerConfig.Int(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Int(context2.Background(), key) +} + +func (b *beegoAppConfig) Int64(key string) (int64, error) { + if v, err := b.innerConfig.Int64(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Int64(context2.Background(), key) +} + +func (b *beegoAppConfig) Bool(key string) (bool, error) { + if v, err := b.innerConfig.Bool(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Bool(context2.Background(), key) +} + +func (b *beegoAppConfig) Float(key string) (float64, error) { + if v, err := b.innerConfig.Float(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Float(context2.Background(), key) +} + +func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { + if v := b.String(key); v != "" { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { + if v := b.Strings(key); len(v) != 0 { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { + if v, err := b.Int(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { + if v, err := b.Int64(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { + if v, err := b.Bool(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { + if v, err := b.Float(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DIY(key string) (interface{}, error) { + return b.innerConfig.DIY(context2.Background(), key) +} + +func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { + return b.innerConfig.GetSection(context2.Background(), section) +} + +func (b *beegoAppConfig) SaveConfigFile(filename string) error { + return b.innerConfig.SaveConfigFile(context2.Background(), filename) +} diff --git a/pkg/adapter/controller.go b/pkg/adapter/controller.go new file mode 100644 index 00000000..010add64 --- /dev/null +++ b/pkg/adapter/controller.go @@ -0,0 +1,401 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "mime/multipart" + "net/url" + + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/adapter/session" + webContext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +var ( + // ErrAbort custom error when user stop request handler manually. + ErrAbort = web.ErrAbort + // GlobalControllerRouter store comments with controller. pkgpath+controller:comments + GlobalControllerRouter = web.GlobalControllerRouter +) + +// ControllerFilter store the filter for controller +type ControllerFilter web.ControllerFilter + +// ControllerFilterComments store the comment for controller level filter +type ControllerFilterComments web.ControllerFilterComments + +// ControllerImportComments store the import comment for controller needed +type ControllerImportComments web.ControllerImportComments + +// ControllerComments store the comment for the controller method +type ControllerComments web.ControllerComments + +// ControllerCommentsSlice implements the sort interface +type ControllerCommentsSlice web.ControllerCommentsSlice + +func (p ControllerCommentsSlice) Len() int { + return (web.ControllerCommentsSlice)(p).Len() +} +func (p ControllerCommentsSlice) Less(i, j int) bool { + return (web.ControllerCommentsSlice)(p).Less(i, j) +} +func (p ControllerCommentsSlice) Swap(i, j int) { + (web.ControllerCommentsSlice)(p).Swap(i, j) +} + +// Controller defines some basic http request handler operations, such as +// http context, template and view, session and xsrf. +type Controller web.Controller + +// ControllerInterface is an interface to uniform all controller handler. +type ControllerInterface web.ControllerInterface + +// Init generates default values of controller operations. +func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { + (*web.Controller)(c).Init((*webContext.Context)(ctx), controllerName, actionName, app) +} + +// Prepare runs after Init before request function execution. +func (c *Controller) Prepare() { + (*web.Controller)(c).Prepare() +} + +// Finish runs after request function execution. +func (c *Controller) Finish() { + (*web.Controller)(c).Finish() +} + +// Get adds a request function to handle GET request. +func (c *Controller) Get() { + (*web.Controller)(c).Get() +} + +// Post adds a request function to handle POST request. +func (c *Controller) Post() { + (*web.Controller)(c).Post() +} + +// Delete adds a request function to handle DELETE request. +func (c *Controller) Delete() { + (*web.Controller)(c).Delete() +} + +// Put adds a request function to handle PUT request. +func (c *Controller) Put() { + (*web.Controller)(c).Put() +} + +// Head adds a request function to handle HEAD request. +func (c *Controller) Head() { + (*web.Controller)(c).Head() +} + +// Patch adds a request function to handle PATCH request. +func (c *Controller) Patch() { + (*web.Controller)(c).Patch() +} + +// Options adds a request function to handle OPTIONS request. +func (c *Controller) Options() { + (*web.Controller)(c).Options() +} + +// Trace adds a request function to handle Trace request. +// this method SHOULD NOT be overridden. +// https://tools.ietf.org/html/rfc7231#section-4.3.8 +// The TRACE method requests a remote, application-level loop-back of +// the request message. The final recipient of the request SHOULD +// reflect the message received, excluding some fields described below, +// back to the client as the message body of a 200 (OK) response with a +// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]). +func (c *Controller) Trace() { + (*web.Controller)(c).Trace() +} + +// HandlerFunc call function with the name +func (c *Controller) HandlerFunc(fnname string) bool { + return (*web.Controller)(c).HandlerFunc(fnname) +} + +// URLMapping register the internal Controller router. +func (c *Controller) URLMapping() { + (*web.Controller)(c).URLMapping() +} + +// Mapping the method to function +func (c *Controller) Mapping(method string, fn func()) { + (*web.Controller)(c).Mapping(method, fn) +} + +// Render sends the response with rendered template bytes as text/html type. +func (c *Controller) Render() error { + return (*web.Controller)(c).Render() +} + +// RenderString returns the rendered template string. Do not send out response. +func (c *Controller) RenderString() (string, error) { + return (*web.Controller)(c).RenderString() +} + +// RenderBytes returns the bytes of rendered template string. Do not send out response. +func (c *Controller) RenderBytes() ([]byte, error) { + return (*web.Controller)(c).RenderBytes() +} + +// Redirect sends the redirection response to url with status code. +func (c *Controller) Redirect(url string, code int) { + (*web.Controller)(c).Redirect(url, code) +} + +// SetData set the data depending on the accepted +func (c *Controller) SetData(data interface{}) { + (*web.Controller)(c).SetData(data) +} + +// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. +func (c *Controller) Abort(code string) { + (*web.Controller)(c).Abort(code) +} + +// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. +func (c *Controller) CustomAbort(status int, body string) { + (*web.Controller)(c).CustomAbort(status, body) +} + +// StopRun makes panic of USERSTOPRUN error and go to recover function if defined. +func (c *Controller) StopRun() { + (*web.Controller)(c).StopRun() +} + +// URLFor does another controller handler in this request function. +// it goes to this controller method if endpoint is not clear. +func (c *Controller) URLFor(endpoint string, values ...interface{}) string { + return (*web.Controller)(c).URLFor(endpoint, values...) +} + +// ServeJSON sends a json response with encoding charset. +func (c *Controller) ServeJSON(encoding ...bool) { + (*web.Controller)(c).ServeJSON(encoding...) +} + +// ServeJSONP sends a jsonp response. +func (c *Controller) ServeJSONP() { + (*web.Controller)(c).ServeJSONP() +} + +// ServeXML sends xml response. +func (c *Controller) ServeXML() { + (*web.Controller)(c).ServeXML() +} + +// ServeYAML sends yaml response. +func (c *Controller) ServeYAML() { + (*web.Controller)(c).ServeYAML() +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (c *Controller) ServeFormatted(encoding ...bool) { + (*web.Controller)(c).ServeFormatted(encoding...) +} + +// Input returns the input data map from POST or PUT request body and query string. +func (c *Controller) Input() url.Values { + return (*web.Controller)(c).Input() +} + +// ParseForm maps input data map to obj struct. +func (c *Controller) ParseForm(obj interface{}) error { + return (*web.Controller)(c).ParseForm(obj) +} + +// GetString returns the input value by key string or the default value while it's present and input is blank +func (c *Controller) GetString(key string, def ...string) string { + return (*web.Controller)(c).GetString(key, def...) +} + +// GetStrings returns the input string slice by key string or the default value while it's present and input is blank +// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. +func (c *Controller) GetStrings(key string, def ...[]string) []string { + return (*web.Controller)(c).GetStrings(key, def...) +} + +// GetInt returns input as an int or the default value while it's present and input is blank +func (c *Controller) GetInt(key string, def ...int) (int, error) { + return (*web.Controller)(c).GetInt(key, def...) +} + +// GetInt8 return input as an int8 or the default value while it's present and input is blank +func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { + return (*web.Controller)(c).GetInt8(key, def...) +} + +// GetUint8 return input as an uint8 or the default value while it's present and input is blank +func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { + return (*web.Controller)(c).GetUint8(key, def...) +} + +// GetInt16 returns input as an int16 or the default value while it's present and input is blank +func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { + return (*web.Controller)(c).GetInt16(key, def...) +} + +// GetUint16 returns input as an uint16 or the default value while it's present and input is blank +func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { + return (*web.Controller)(c).GetUint16(key, def...) +} + +// GetInt32 returns input as an int32 or the default value while it's present and input is blank +func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { + return (*web.Controller)(c).GetInt32(key, def...) +} + +// GetUint32 returns input as an uint32 or the default value while it's present and input is blank +func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { + return (*web.Controller)(c).GetUint32(key, def...) +} + +// GetInt64 returns input value as int64 or the default value while it's present and input is blank. +func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { + return (*web.Controller)(c).GetInt64(key, def...) +} + +// GetUint64 returns input value as uint64 or the default value while it's present and input is blank. +func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { + return (*web.Controller)(c).GetUint64(key, def...) +} + +// GetBool returns input value as bool or the default value while it's present and input is blank. +func (c *Controller) GetBool(key string, def ...bool) (bool, error) { + return (*web.Controller)(c).GetBool(key, def...) +} + +// GetFloat returns input value as float64 or the default value while it's present and input is blank. +func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { + return (*web.Controller)(c).GetFloat(key, def...) +} + +// GetFile returns the file data in file upload field named as key. +// it returns the first one of multi-uploaded files. +func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, error) { + return (*web.Controller)(c).GetFile(key) +} + +// GetFiles return multi-upload files +// files, err:=c.GetFiles("myfiles") +// if err != nil { +// http.Error(w, err.Error(), http.StatusNoContent) +// return +// } +// for i, _ := range files { +// //for each fileheader, get a handle to the actual file +// file, err := files[i].Open() +// defer file.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //create destination file making sure the path is writeable. +// dst, err := os.Create("upload/" + files[i].Filename) +// defer dst.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //copy the uploaded file to the destination file +// if _, err := io.Copy(dst, file); err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// } +func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { + return (*web.Controller)(c).GetFiles(key) +} + +// SaveToFile saves uploaded file to new path. +// it only operates the first one of mutil-upload form file field. +func (c *Controller) SaveToFile(fromfile, tofile string) error { + return (*web.Controller)(c).SaveToFile(fromfile, tofile) +} + +// StartSession starts session and load old session data info this controller. +func (c *Controller) StartSession() session.Store { + s := (*web.Controller)(c).StartSession() + return session.CreateNewToOldStoreAdapter(s) +} + +// SetSession puts value into session. +func (c *Controller) SetSession(name interface{}, value interface{}) { + (*web.Controller)(c).SetSession(name, value) +} + +// GetSession gets value from session. +func (c *Controller) GetSession(name interface{}) interface{} { + return (*web.Controller)(c).GetSession(name) +} + +// DelSession removes value from session. +func (c *Controller) DelSession(name interface{}) { + (*web.Controller)(c).DelSession(name) +} + +// SessionRegenerateID regenerates session id for this session. +// the session data have no changes. +func (c *Controller) SessionRegenerateID() { + (*web.Controller)(c).SessionRegenerateID() +} + +// DestroySession cleans session data and session cookie. +func (c *Controller) DestroySession() { + (*web.Controller)(c).DestroySession() +} + +// IsAjax returns this request is ajax or not. +func (c *Controller) IsAjax() bool { + return (*web.Controller)(c).IsAjax() +} + +// GetSecureCookie returns decoded cookie value from encoded browser cookie values. +func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { + return (*web.Controller)(c).GetSecureCookie(Secret, key) +} + +// SetSecureCookie puts value into cookie after encoded the value. +func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) { + (*web.Controller)(c).SetSecureCookie(Secret, name, value, others...) +} + +// XSRFToken creates a CSRF token string and returns. +func (c *Controller) XSRFToken() string { + return (*web.Controller)(c).XSRFToken() +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (c *Controller) CheckXSRFCookie() bool { + return (*web.Controller)(c).CheckXSRFCookie() +} + +// XSRFFormHTML writes an input field contains xsrf token value. +func (c *Controller) XSRFFormHTML() string { + return (*web.Controller)(c).XSRFFormHTML() +} + +// GetControllerAndAction gets the executing controller name and action name. +func (c *Controller) GetControllerAndAction() (string, string) { + return (*web.Controller)(c).GetControllerAndAction() +} diff --git a/pkg/adapter/error.go b/pkg/adapter/error.go new file mode 100644 index 00000000..4f08aa8c --- /dev/null +++ b/pkg/adapter/error.go @@ -0,0 +1,202 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +const ( + errorTypeHandler = iota + errorTypeController +) + +var tpl = ` + + + + + beego application error + + + + + +
+ + + + + + + + + + +
Request Method: {{.RequestMethod}}
Request URL: {{.RequestURL}}
RemoteAddr: {{.RemoteAddr }}
+
+ Stack +
{{.Stack}}
+
+
+ + + +` + +var errtpl = ` + + + + + {{.Title}} + + + +
+
+ +
+ {{.Content}} + Go Home
+ +
Powered by beego {{.BeegoVersion}} +
+
+
+ + +` + +// ErrorMaps holds map of http handlers for each error string. +// there is 10 kinds default error(40x and 50x) +var ErrorMaps = web.ErrorMaps + +// ErrorHandler registers http.HandlerFunc to each http err code string. +// usage: +// beego.ErrorHandler("404",NotFound) +// beego.ErrorHandler("500",InternalServerError) +func ErrorHandler(code string, h http.HandlerFunc) *App { + return (*App)(web.ErrorHandler(code, h)) +} + +// ErrorController registers ControllerInterface to each http err code string. +// usage: +// beego.ErrorController(&controllers.ErrorController{}) +func ErrorController(c ControllerInterface) *App { + return (*App)(web.ErrorController(c)) +} + +// Exception Write HttpStatus with errCode and Exec error handler if exist. +func Exception(errCode uint64, ctx *context.Context) { + web.Exception(errCode, (*beecontext.Context)(ctx)) +} diff --git a/pkg/adapter/filter.go b/pkg/adapter/filter.go new file mode 100644 index 00000000..cafed773 --- /dev/null +++ b/pkg/adapter/filter.go @@ -0,0 +1,36 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web" + beecontext "github.com/astaxie/beego/pkg/server/web/context" +) + +// FilterFunc defines a filter function which is invoked before the controller handler is executed. +type FilterFunc func(*context.Context) + +// FilterRouter defines a filter operation which is invoked before the controller handler is executed. +// It can match the URL against a pattern, and execute a filter function +// when a request with a matching URL arrives. +type FilterRouter web.FilterRouter + +// ValidRouter checks if the current request is matched by this filter. +// If the request is matched, the values of the URL parameters defined +// by the filter pattern are also returned. +func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { + return (*web.FilterRouter)(f).ValidRouter(url, (*beecontext.Context)(ctx)) +} diff --git a/pkg/adapter/flash.go b/pkg/adapter/flash.go new file mode 100644 index 00000000..e5e1c187 --- /dev/null +++ b/pkg/adapter/flash.go @@ -0,0 +1,63 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/server/web" +) + +// FlashData is a tools to maintain data when using across request. +type FlashData web.FlashData + +// NewFlash return a new empty FlashData struct. +func NewFlash() *FlashData { + return (*FlashData)(web.NewFlash()) +} + +// Set message to flash +func (fd *FlashData) Set(key string, msg string, args ...interface{}) { + (*web.FlashData)(fd).Set(key, msg, args) +} + +// Success writes success message to flash. +func (fd *FlashData) Success(msg string, args ...interface{}) { + (*web.FlashData)(fd).Success(msg, args...) +} + +// Notice writes notice message to flash. +func (fd *FlashData) Notice(msg string, args ...interface{}) { + (*web.FlashData)(fd).Notice(msg, args...) +} + +// Warning writes warning message to flash. +func (fd *FlashData) Warning(msg string, args ...interface{}) { + (*web.FlashData)(fd).Warning(msg, args...) +} + +// Error writes error message to flash. +func (fd *FlashData) Error(msg string, args ...interface{}) { + (*web.FlashData)(fd).Error(msg, args...) +} + +// Store does the saving operation of flash data. +// the data are encoded and saved in cookie. +func (fd *FlashData) Store(c *Controller) { + (*web.FlashData)(fd).Store((*web.Controller)(c)) +} + +// ReadFromRequest parsed flash data from encoded values in cookie. +func ReadFromRequest(c *Controller) *FlashData { + return (*FlashData)(web.ReadFromRequest((*web.Controller)(c))) +} diff --git a/pkg/adapter/fs.go b/pkg/adapter/fs.go new file mode 100644 index 00000000..07054ca3 --- /dev/null +++ b/pkg/adapter/fs.go @@ -0,0 +1,35 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + "path/filepath" + + "github.com/astaxie/beego/pkg/server/web" +) + +type FileSystem web.FileSystem + +func (d FileSystem) Open(name string) (http.File, error) { + return (web.FileSystem)(d).Open(name) +} + +// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or +// directory in the tree, including root. All errors that arise visiting files +// and directories are filtered by walkFn. +func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { + return web.Walk(fs, root, walkFn) +} diff --git a/pkg/adapter/log.go b/pkg/adapter/log.go new file mode 100644 index 00000000..d9ff6e0c --- /dev/null +++ b/pkg/adapter/log.go @@ -0,0 +1,129 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "strings" + + "github.com/astaxie/beego/pkg/infrastructure/logs" + + webLog "github.com/astaxie/beego/pkg/infrastructure/logs" +) + +// Log levels to control the logging output. +// Deprecated: use github.com/astaxie/beego/logs instead. +const ( + LevelEmergency = webLog.LevelEmergency + LevelAlert = webLog.LevelAlert + LevelCritical = webLog.LevelCritical + LevelError = webLog.LevelError + LevelWarning = webLog.LevelWarning + LevelNotice = webLog.LevelNotice + LevelInformational = webLog.LevelInformational + LevelDebug = webLog.LevelDebug +) + +// BeeLogger references the used application logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +var BeeLogger = logs.GetBeeLogger() + +// SetLevel sets the global log level used by the simple logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLevel(l int) { + logs.SetLevel(l) +} + +// SetLogFuncCall set the CallDepth, default is 3 +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogFuncCall(b bool) { + logs.SetLogFuncCall(b) +} + +// SetLogger sets a new logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogger(adaptername string, config string) error { + return logs.SetLogger(adaptername, config) +} + +// Emergency logs a message at emergency level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Emergency(v ...interface{}) { + logs.Emergency(generateFmtStr(len(v)), v...) +} + +// Alert logs a message at alert level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Alert(v ...interface{}) { + logs.Alert(generateFmtStr(len(v)), v...) +} + +// Critical logs a message at critical level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Critical(v ...interface{}) { + logs.Critical(generateFmtStr(len(v)), v...) +} + +// Error logs a message at error level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Error(v ...interface{}) { + logs.Error(generateFmtStr(len(v)), v...) +} + +// Warning logs a message at warning level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warning(v ...interface{}) { + logs.Warning(generateFmtStr(len(v)), v...) +} + +// Warn compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warn(v ...interface{}) { + logs.Warn(generateFmtStr(len(v)), v...) +} + +// Notice logs a message at notice level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Notice(v ...interface{}) { + logs.Notice(generateFmtStr(len(v)), v...) +} + +// Informational logs a message at info level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Informational(v ...interface{}) { + logs.Informational(generateFmtStr(len(v)), v...) +} + +// Info compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Info(v ...interface{}) { + logs.Info(generateFmtStr(len(v)), v...) +} + +// Debug logs a message at debug level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Debug(v ...interface{}) { + logs.Debug(generateFmtStr(len(v)), v...) +} + +// Trace logs a message at trace level. +// compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Trace(v ...interface{}) { + logs.Trace(generateFmtStr(len(v)), v...) +} + +func generateFmtStr(n int) string { + return strings.Repeat("%v ", n) +} diff --git a/pkg/adapter/namespace.go b/pkg/adapter/namespace.go new file mode 100644 index 00000000..609402cf --- /dev/null +++ b/pkg/adapter/namespace.go @@ -0,0 +1,378 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + + adtContext "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +type namespaceCond func(*adtContext.Context) bool + +// LinkNamespace used as link action +type LinkNamespace func(*Namespace) + +// Namespace is store all the info +type Namespace web.Namespace + +// NewNamespace get new Namespace +func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { + nps := oldToNewLinkNs(params) + return (*Namespace)(web.NewNamespace(prefix, nps...)) +} + +func oldToNewLinkNs(params []LinkNamespace) []web.LinkNamespace { + nps := make([]web.LinkNamespace, 0, len(params)) + for _, p := range params { + nps = append(nps, func(namespace *web.Namespace) { + p((*Namespace)(namespace)) + }) + } + return nps +} + +// Cond set condition function +// if cond return true can run this namespace, else can't +// usage: +// ns.Cond(func (ctx *context.Context) bool{ +// if ctx.Input.Domain() == "api.beego.me" { +// return true +// } +// return false +// }) +// Cond as the first filter +func (n *Namespace) Cond(cond namespaceCond) *Namespace { + (*web.Namespace)(n).Cond(func(context *context.Context) bool { + return cond((*adtContext.Context)(context)) + }) + return n +} + +// Filter add filter in the Namespace +// action has before & after +// FilterFunc +// usage: +// Filter("before", func (ctx *context.Context){ +// _, ok := ctx.Input.Session("uid").(int) +// if !ok && ctx.Request.RequestURI != "/login" { +// ctx.Redirect(302, "/login") +// } +// }) +func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { + nfs := oldToNewFilter(filter) + (*web.Namespace)(n).Filter(action, nfs...) + return n +} + +func oldToNewFilter(filter []FilterFunc) []web.FilterFunc { + nfs := make([]web.FilterFunc, 0, len(filter)) + for _, f := range filter { + nfs = append(nfs, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } + return nfs +} + +// Router same as beego.Rourer +// refer: https://godoc.org/github.com/astaxie/beego#Router +func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { + (*web.Namespace)(n).Router(rootpath, c, mappingMethods...) + return n +} + +// AutoRouter same as beego.AutoRouter +// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter +func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { + (*web.Namespace)(n).AutoRouter(c) + return n +} + +// AutoPrefix same as beego.AutoPrefix +// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix +func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { + (*web.Namespace)(n).AutoPrefix(prefix, c) + return n +} + +// Get same as beego.Get +// refer: https://godoc.org/github.com/astaxie/beego#Get +func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Get(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Post same as beego.Post +// refer: https://godoc.org/github.com/astaxie/beego#Post +func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Post(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Delete same as beego.Delete +// refer: https://godoc.org/github.com/astaxie/beego#Delete +func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Delete(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Put same as beego.Put +// refer: https://godoc.org/github.com/astaxie/beego#Put +func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Put(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Head same as beego.Head +// refer: https://godoc.org/github.com/astaxie/beego#Head +func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Head(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Options same as beego.Options +// refer: https://godoc.org/github.com/astaxie/beego#Options +func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Options(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Patch same as beego.Patch +// refer: https://godoc.org/github.com/astaxie/beego#Patch +func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Patch(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Any same as beego.Any +// refer: https://godoc.org/github.com/astaxie/beego#Any +func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Any(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Handler same as beego.Handler +// refer: https://godoc.org/github.com/astaxie/beego#Handler +func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { + (*web.Namespace)(n).Handler(rootpath, h) + return n +} + +// Include add include class +// refer: https://godoc.org/github.com/astaxie/beego#Include +func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { + nL := oldToNewCtrlIntfs(cList) + (*web.Namespace)(n).Include(nL...) + return n +} + +// Namespace add nest Namespace +// usage: +// ns := beego.NewNamespace(“/v1”). +// Namespace( +// beego.NewNamespace("/shop"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("shopinfo")) +// }), +// beego.NewNamespace("/order"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("orderinfo")) +// }), +// beego.NewNamespace("/crm"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("crminfo")) +// }), +// ) +func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { + nns := oldToNewNs(ns) + (*web.Namespace)(n).Namespace(nns...) + return n +} + +func oldToNewNs(ns []*Namespace) []*web.Namespace { + nns := make([]*web.Namespace, 0, len(ns)) + for _, n := range ns { + nns = append(nns, (*web.Namespace)(n)) + } + return nns +} + +// AddNamespace register Namespace into beego.Handler +// support multi Namespace +func AddNamespace(nl ...*Namespace) { + nnl := oldToNewNs(nl) + web.AddNamespace(nnl...) +} + +// NSCond is Namespace Condition +func NSCond(cond namespaceCond) LinkNamespace { + return func(namespace *Namespace) { + web.NSCond(func(b *context.Context) bool { + return cond((*adtContext.Context)(b)) + }) + } +} + +// NSBefore Namespace BeforeRouter filter +func NSBefore(filterList ...FilterFunc) LinkNamespace { + return func(namespace *Namespace) { + nfs := oldToNewFilter(filterList) + web.NSBefore(nfs...) + } +} + +// NSAfter add Namespace FinishRouter filter +func NSAfter(filterList ...FilterFunc) LinkNamespace { + return func(namespace *Namespace) { + nfs := oldToNewFilter(filterList) + web.NSAfter(nfs...) + } +} + +// NSInclude Namespace Include ControllerInterface +func NSInclude(cList ...ControllerInterface) LinkNamespace { + return func(namespace *Namespace) { + nfs := oldToNewCtrlIntfs(cList) + web.NSInclude(nfs...) + } +} + +// NSRouter call Namespace Router +func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { + return func(namespace *Namespace) { + web.Router(rootpath, c, mappingMethods...) + } +} + +// NSGet call Namespace Get +func NSGet(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSGet(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSPost call Namespace Post +func NSPost(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.Post(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSHead call Namespace Head +func NSHead(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSHead(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSPut call Namespace Put +func NSPut(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSPut(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSDelete call Namespace Delete +func NSDelete(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSDelete(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSAny call Namespace Any +func NSAny(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSAny(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSOptions call Namespace Options +func NSOptions(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSOptions(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSPatch call Namespace Patch +func NSPatch(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSPatch(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSAutoRouter call Namespace AutoRouter +func NSAutoRouter(c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + web.NSAutoRouter(c) + } +} + +// NSAutoPrefix call Namespace AutoPrefix +func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + web.NSAutoPrefix(prefix, c) + } +} + +// NSNamespace add sub Namespace +func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { + return func(ns *Namespace) { + nps := oldToNewLinkNs(params) + web.NSNamespace(prefix, nps...) + } +} + +// NSHandler add handler +func NSHandler(rootpath string, h http.Handler) LinkNamespace { + return func(ns *Namespace) { + web.NSHandler(rootpath, h) + } +} diff --git a/pkg/adapter/policy.go b/pkg/adapter/policy.go new file mode 100644 index 00000000..f3759c76 --- /dev/null +++ b/pkg/adapter/policy.go @@ -0,0 +1,57 @@ +// Copyright 2016 beego authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web" + beecontext "github.com/astaxie/beego/pkg/server/web/context" +) + +// PolicyFunc defines a policy function which is invoked before the controller handler is executed. +type PolicyFunc func(*context.Context) + +// FindPolicy Find Router info for URL +func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { + pf := (*web.ControllerRegister)(p).FindPolicy((*beecontext.Context)(cont)) + npf := newToOldPolicyFunc(pf) + return npf +} + +func newToOldPolicyFunc(pf []web.PolicyFunc) []PolicyFunc { + npf := make([]PolicyFunc, 0, len(pf)) + for _, f := range pf { + npf = append(npf, func(c *context.Context) { + f((*beecontext.Context)(c)) + }) + } + return npf +} + +func oldToNewPolicyFunc(pf []PolicyFunc) []web.PolicyFunc { + npf := make([]web.PolicyFunc, 0, len(pf)) + for _, f := range pf { + npf = append(npf, func(c *beecontext.Context) { + f((*context.Context)(c)) + }) + } + return npf +} + +// Policy Register new policy in beego +func Policy(pattern, method string, policy ...PolicyFunc) { + pf := oldToNewPolicyFunc(policy) + web.Policy(pattern, method, pf...) +} diff --git a/pkg/adapter/router.go b/pkg/adapter/router.go new file mode 100644 index 00000000..5a36fbee --- /dev/null +++ b/pkg/adapter/router.go @@ -0,0 +1,279 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + "time" + + beecontext "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +// default filter execution points +const ( + BeforeStatic = web.BeforeStatic + BeforeRouter = web.BeforeRouter + BeforeExec = web.BeforeExec + AfterExec = web.AfterExec + FinishRouter = web.FinishRouter +) + +var ( + // HTTPMETHOD list the supported http methods. + HTTPMETHOD = web.HTTPMETHOD + + // DefaultAccessLogFilter will skip the accesslog if return true + DefaultAccessLogFilter FilterHandler = &newToOldFtHdlAdapter{ + delegate: web.DefaultAccessLogFilter, + } +) + +// FilterHandler is an interface for +type FilterHandler interface { + Filter(*beecontext.Context) bool +} + +type newToOldFtHdlAdapter struct { + delegate web.FilterHandler +} + +func (n *newToOldFtHdlAdapter) Filter(ctx *beecontext.Context) bool { + return n.delegate.Filter((*context.Context)(ctx)) +} + +// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter +func ExceptMethodAppend(action string) { + web.ExceptMethodAppend(action) +} + +// ControllerInfo holds information about the controller. +type ControllerInfo web.ControllerInfo + +func (c *ControllerInfo) GetPattern() string { + return (*web.ControllerInfo)(c).GetPattern() +} + +// ControllerRegister containers registered router rules, controller handlers and filters. +type ControllerRegister web.ControllerRegister + +// NewControllerRegister returns a new ControllerRegister. +func NewControllerRegister() *ControllerRegister { + return (*ControllerRegister)(web.NewControllerRegister()) +} + +// Add controller handler and pattern rules to ControllerRegister. +// usage: +// default methods is the same name as method +// Add("/user",&UserController{}) +// Add("/api/list",&RestController{},"*:ListFood") +// Add("/api/create",&RestController{},"post:CreateFood") +// Add("/api/update",&RestController{},"put:UpdateFood") +// Add("/api/delete",&RestController{},"delete:DeleteFood") +// Add("/api",&RestController{},"get,post:ApiFunc" +// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { + (*web.ControllerRegister)(p).Add(pattern, c, mappingMethods...) +} + +// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller +// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +func (p *ControllerRegister) Include(cList ...ControllerInterface) { + nls := oldToNewCtrlIntfs(cList) + (*web.ControllerRegister)(p).Include(nls...) +} + +// GetContext returns a context from pool, so usually you should remember to call Reset function to clean the context +// And don't forget to give back context to pool +// example: +// ctx := p.GetContext() +// ctx.Reset(w, q) +// defer p.GiveBackContext(ctx) +func (p *ControllerRegister) GetContext() *beecontext.Context { + return (*beecontext.Context)((*web.ControllerRegister)(p).GetContext()) +} + +// GiveBackContext put the ctx into pool so that it could be reuse +func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { + (*web.ControllerRegister)(p).GiveBackContext((*context.Context)(ctx)) +} + +// Get add get method +// usage: +// Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Get(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Get(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Post add post method +// usage: +// Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Post(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Post(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Put add put method +// usage: +// Put("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Put(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Put(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Delete add delete method +// usage: +// Delete("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Delete(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Head add head method +// usage: +// Head("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Head(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Head(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Patch add patch method +// usage: +// Patch("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Patch(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Options add options method +// usage: +// Options("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Options(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Options(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Any add all method +// usage: +// Any("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Any(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Any(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// AddMethod add http method router +// usage: +// AddMethod("get","/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).AddMethod(method, pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Handler add user defined Handler +func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { + (*web.ControllerRegister)(p).Handler(pattern, h, options) +} + +// AddAuto router to ControllerRegister. +// example beego.AddAuto(&MainContorlller{}), +// MainController has method List and Page. +// visit the url /main/list to execute List function +// /main/page to execute Page function. +func (p *ControllerRegister) AddAuto(c ControllerInterface) { + (*web.ControllerRegister)(p).AddAuto(c) +} + +// AddAutoPrefix Add auto router to ControllerRegister with prefix. +// example beego.AddAutoPrefix("/admin",&MainContorlller{}), +// MainController has method List and Page. +// visit the url /admin/main/list to execute List function +// /admin/main/page to execute Page function. +func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) { + (*web.ControllerRegister)(p).AddAutoPrefix(prefix, c) +} + +// InsertFilter Add a FilterFunc with pattern rule and action constant. +// params is for: +// 1. setting the returnOnOutput value (false allows multiple filters to execute) +// 2. determining whether or not params need to be reset. +func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { + opts := oldToNewFilterOpts(params) + return (*web.ControllerRegister)(p).InsertFilter(pattern, pos, func(ctx *context.Context) { + filter((*beecontext.Context)(ctx)) + }, opts...) +} + +func oldToNewFilterOpts(params []bool) []web.FilterOpt { + opts := make([]web.FilterOpt, 0, 4) + if len(params) > 0 { + opts = append(opts, web.WithReturnOnOutput(params[0])) + } + if len(params) > 1 { + opts = append(opts, web.WithResetParams(params[1])) + } + return opts +} + +// URLFor does another controller handler in this request function. +// it can access any controller method. +func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { + return (*web.ControllerRegister)(p).URLFor(endpoint, values...) +} + +// Implement http.Handler interface. +func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + (*web.ControllerRegister)(p).ServeHTTP(rw, r) +} + +// FindRouter Find Router info for URL +func (p *ControllerRegister) FindRouter(ctx *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { + r, ok := (*web.ControllerRegister)(p).FindRouter((*context.Context)(ctx)) + return (*ControllerInfo)(r), ok +} + +// LogAccess logging info HTTP Access +func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { + web.LogAccess((*context.Context)(ctx), startTime, statusCode) +} diff --git a/pkg/adapter/template.go b/pkg/adapter/template.go new file mode 100644 index 00000000..1f943caf --- /dev/null +++ b/pkg/adapter/template.go @@ -0,0 +1,108 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "html/template" + "io" + "net/http" + + "github.com/astaxie/beego/pkg/server/web" +) + +// ExecuteTemplate applies the template with name to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { + return web.ExecuteTemplate(wr, name, data) +} + +// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error { + return web.ExecuteViewPathTemplate(wr, name, viewPath, data) +} + +// AddFuncMap let user to register a func in the template. +func AddFuncMap(key string, fn interface{}) error { + return web.AddFuncMap(key, fn) +} + +type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error) + +type templateFile struct { + root string + files map[string][]string +} + +// HasTemplateExt return this path contains supported template extension of beego or not. +func HasTemplateExt(paths string) bool { + return web.HasTemplateExt(paths) +} + +// AddTemplateExt add new extension for template. +func AddTemplateExt(ext string) { + web.AddTemplateExt(ext) +} + +// AddViewPath adds a new path to the supported view paths. +// Can later be used by setting a controller ViewPath to this folder +// will panic if called after beego.Run() +func AddViewPath(viewPath string) error { + return web.AddViewPath(viewPath) +} + +// BuildTemplate will build all template files in a directory. +// it makes beego can render any template file in view directory. +func BuildTemplate(dir string, files ...string) error { + return web.BuildTemplate(dir, files...) +} + +type templateFSFunc func() http.FileSystem + +func defaultFSFunc() http.FileSystem { + return FileSystem{} +} + +// SetTemplateFSFunc set default filesystem function +func SetTemplateFSFunc(fnt templateFSFunc) { + web.SetTemplateFSFunc(func() http.FileSystem { + return fnt() + }) +} + +// SetViewsPath sets view directory path in beego application. +func SetViewsPath(path string) *App { + return (*App)(web.SetViewsPath(path)) +} + +// SetStaticPath sets static directory path and proper url pattern in beego application. +// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". +func SetStaticPath(url string, path string) *App { + return (*App)(web.SetStaticPath(url, path)) +} + +// DelStaticPath removes the static folder setting in this url pattern in beego application. +func DelStaticPath(url string) *App { + return (*App)(web.DelStaticPath(url)) +} + +// AddTemplateEngine add a new templatePreProcessor which support extension +func AddTemplateEngine(extension string, fn templatePreProcessor) *App { + return (*App)(web.AddTemplateEngine(extension, func(root, path string, funcs template.FuncMap) (*template.Template, error) { + return fn(root, path, funcs) + })) +} diff --git a/pkg/adapter/templatefunc.go b/pkg/adapter/templatefunc.go new file mode 100644 index 00000000..5130d590 --- /dev/null +++ b/pkg/adapter/templatefunc.go @@ -0,0 +1,151 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "html/template" + "net/url" + "time" + + "github.com/astaxie/beego/pkg/server/web" +) + +const ( + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" + formatDateTimeT = "2006-01-02T15:04:05" +) + +// Substr returns the substr from start to length. +func Substr(s string, start, length int) string { + return web.Substr(s, start, length) +} + +// HTML2str returns escaping text convert from html. +func HTML2str(html string) string { + return web.HTML2str(html) +} + +// DateFormat takes a time and a layout string and returns a string with the formatted date. Used by the template parser as "dateformat" +func DateFormat(t time.Time, layout string) (datestring string) { + return web.DateFormat(t, layout) +} + +// DateParse Parse Date use PHP time format. +func DateParse(dateString, format string) (time.Time, error) { + return web.DateParse(dateString, format) +} + +// Date takes a PHP like date func to Go's time format. +func Date(t time.Time, format string) string { + return web.Date(t, format) +} + +// Compare is a quick and dirty comparison function. It will convert whatever you give it to strings and see if the two values are equal. +// Whitespace is trimmed. Used by the template parser as "eq". +func Compare(a, b interface{}) (equal bool) { + return web.Compare(a, b) +} + +// CompareNot !Compare +func CompareNot(a, b interface{}) (equal bool) { + return web.CompareNot(a, b) +} + +// NotNil the same as CompareNot +func NotNil(a interface{}) (isNil bool) { + return web.NotNil(a) +} + +// GetConfig get the Appconfig +func GetConfig(returnType, key string, defaultVal interface{}) (interface{}, error) { + return web.GetConfig(returnType, key, defaultVal) +} + +// Str2html Convert string to template.HTML type. +func Str2html(raw string) template.HTML { + return web.Str2html(raw) +} + +// Htmlquote returns quoted html string. +func Htmlquote(text string) string { + return web.Htmlquote(text) +} + +// Htmlunquote returns unquoted html string. +func Htmlunquote(text string) string { + return web.Htmlunquote(text) +} + +// URLFor returns url string with another registered controller handler with params. +// usage: +// +// URLFor(".index") +// print URLFor("index") +// router /login +// print URLFor("login") +// print URLFor("login", "next","/"") +// router /profile/:username +// print UrlFor("profile", ":username","John Doe") +// result: +// / +// /login +// /login?next=/ +// /user/John%20Doe +// +// more detail http://beego.me/docs/mvc/controller/urlbuilding.md +func URLFor(endpoint string, values ...interface{}) string { + return web.URLFor(endpoint, values...) +} + +// AssetsJs returns script tag with src string. +func AssetsJs(text string) template.HTML { + return web.AssetsJs(text) +} + +// AssetsCSS returns stylesheet link tag with src string. +func AssetsCSS(text string) template.HTML { + + text = "" + + return template.HTML(text) +} + +// ParseForm will parse form values to struct via tag. +func ParseForm(form url.Values, obj interface{}) error { + return web.ParseForm(form, obj) +} + +// RenderForm will render object to form html. +// obj must be a struct pointer. +func RenderForm(obj interface{}) template.HTML { + return web.RenderForm(obj) +} + +// MapGet getting value from map by keys +// usage: +// Data["m"] = M{ +// "a": 1, +// "1": map[string]float64{ +// "c": 4, +// }, +// } +// +// {{ map_get m "a" }} // return 1 +// {{ map_get m 1 "c" }} // return 4 +func MapGet(arg1 interface{}, arg2 ...interface{}) (interface{}, error) { + return web.MapGet(arg1, arg2...) +} diff --git a/pkg/adapter/templatefunc_test.go b/pkg/adapter/templatefunc_test.go new file mode 100644 index 00000000..f5113606 --- /dev/null +++ b/pkg/adapter/templatefunc_test.go @@ -0,0 +1,304 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "html/template" + "net/url" + "testing" + "time" +) + +func TestSubstr(t *testing.T) { + s := `012345` + if Substr(s, 0, 2) != "01" { + t.Error("should be equal") + } + if Substr(s, 0, 100) != "012345" { + t.Error("should be equal") + } + if Substr(s, 12, 100) != "012345" { + t.Error("should be equal") + } +} + +func TestHtml2str(t *testing.T) { + h := `<123> 123\n + + + \n` + if HTML2str(h) != "123\\n\n\\n" { + t.Error("should be equal") + } +} + +func TestDateFormat(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := DateFormat(tt, "2006-01-02 15:04:05"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } +} + +func TestDate(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := Date(tt, "Y-m-d H:i:s"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } + if ss := Date(tt, "y-n-j h:i:s A"); ss != "13-7-1 01:27:42 PM" { + t.Errorf("13-7-1 01:27:42 PM does not equal %v", ss) + } + if ss := Date(tt, "D, d M Y g:i:s a"); ss != "Mon, 01 Jul 2013 1:27:42 pm" { + t.Errorf("Mon, 01 Jul 2013 1:27:42 pm does not equal %v", ss) + } + if ss := Date(tt, "l, d F Y G:i:s"); ss != "Monday, 01 July 2013 13:27:42" { + t.Errorf("Monday, 01 July 2013 13:27:42 does not equal %v", ss) + } +} + +func TestCompareRelated(t *testing.T) { + if !Compare("abc", "abc") { + t.Error("should be equal") + } + if Compare("abc", "aBc") { + t.Error("should be not equal") + } + if !Compare("1", 1) { + t.Error("should be equal") + } + if CompareNot("abc", "abc") { + t.Error("should be equal") + } + if !CompareNot("abc", "aBc") { + t.Error("should be not equal") + } + if !NotNil("a string") { + t.Error("should not be nil") + } +} + +func TestHtmlquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlquote(s) != h { + t.Error("should be equal") + } +} + +func TestHtmlunquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlunquote(h) != s { + t.Error("should be equal") + } +} + +func TestParseForm(t *testing.T) { + type ExtendInfo struct { + Hobby []string `form:"hobby"` + Memo string + } + + type OtherInfo struct { + Organization string `form:"organization"` + Title string `form:"title"` + ExtendInfo + } + + type user struct { + ID int `form:"-"` + tag string `form:"tag"` + Name interface{} `form:"username"` + Age int `form:"age,text"` + Email string + Intro string `form:",textarea"` + StrBool bool `form:"strbool"` + Date time.Time `form:"date,2006-01-02"` + OtherInfo + } + + u := user{} + form := url.Values{ + "ID": []string{"1"}, + "-": []string{"1"}, + "tag": []string{"no"}, + "username": []string{"test"}, + "age": []string{"40"}, + "Email": []string{"test@gmail.com"}, + "Intro": []string{"I am an engineer!"}, + "strbool": []string{"yes"}, + "date": []string{"2014-11-12"}, + "organization": []string{"beego"}, + "title": []string{"CXO"}, + "hobby": []string{"", "Basketball", "Football"}, + "memo": []string{"nothing"}, + } + if err := ParseForm(form, u); err == nil { + t.Fatal("nothing will be changed") + } + if err := ParseForm(form, &u); err != nil { + t.Fatal(err) + } + if u.ID != 0 { + t.Errorf("ID should equal 0 but got %v", u.ID) + } + if len(u.tag) != 0 { + t.Errorf("tag's length should equal 0 but got %v", len(u.tag)) + } + if u.Name.(string) != "test" { + t.Errorf("Name should equal `test` but got `%v`", u.Name.(string)) + } + if u.Age != 40 { + t.Errorf("Age should equal 40 but got %v", u.Age) + } + if u.Email != "test@gmail.com" { + t.Errorf("Email should equal `test@gmail.com` but got `%v`", u.Email) + } + if u.Intro != "I am an engineer!" { + t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) + } + if !u.StrBool { + t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) + } + y, m, d := u.Date.Date() + if y != 2014 || m.String() != "November" || d != 12 { + t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String()) + } + if u.Organization != "beego" { + t.Errorf("Organization should equal `beego`, but got `%v`", u.Organization) + } + if u.Title != "CXO" { + t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) + } + if u.Hobby[0] != "" { + t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0]) + } + if u.Hobby[1] != "Basketball" { + t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1]) + } + if u.Hobby[2] != "Football" { + t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2]) + } + if len(u.Memo) != 0 { + t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) + } +} + +func TestRenderForm(t *testing.T) { + type user struct { + ID int `form:"-"` + Name interface{} `form:"username"` + Age int `form:"age,text,年龄:"` + Sex string + Email []string + Intro string `form:",textarea"` + Ignored string `form:"-"` + } + + u := user{Name: "test", Intro: "Some Text"} + output := RenderForm(u) + if output != template.HTML("") { + t.Errorf("output should be empty but got %v", output) + } + output = RenderForm(&u) + result := template.HTML( + `Name:
` + + `年龄:
` + + `Sex:
` + + `Intro: `) + if output != result { + t.Errorf("output should equal `%v` but got `%v`", result, output) + } +} + +func TestMapGet(t *testing.T) { + // test one level map + m1 := map[string]int64{ + "a": 1, + "1": 2, + } + + if res, err := MapGet(m1, "a"); err == nil { + if res.(int64) != 1 { + t.Errorf("Should return 1, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, "1"); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, 1); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 2 level map + m2 := M{ + "1": map[string]float64{ + "2": 3.5, + }, + } + + if res, err := MapGet(m2, 1, 2); err == nil { + if res.(float64) != 3.5 { + t.Errorf("Should return 3.5, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 5 level map + m5 := M{ + "1": M{ + "2": M{ + "3": M{ + "4": M{ + "5": 1.2, + }, + }, + }, + }, + } + + if res, err := MapGet(m5, 1, 2, 3, 4, 5); err == nil { + if res.(float64) != 1.2 { + t.Errorf("Should return 1.2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // check whether element not exists in map + if res, err := MapGet(m5, 5, 4, 3, 2, 1); err == nil { + if res != nil { + t.Errorf("Should return nil, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } +} diff --git a/pkg/adapter/tree.go b/pkg/adapter/tree.go new file mode 100644 index 00000000..2e3cd0d0 --- /dev/null +++ b/pkg/adapter/tree.go @@ -0,0 +1,49 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +// Tree has three elements: FixRouter/wildcard/leaves +// fixRouter stores Fixed Router +// wildcard stores params +// leaves store the endpoint information +type Tree web.Tree + +// NewTree return a new Tree +func NewTree() *Tree { + return (*Tree)(web.NewTree()) +} + +// AddTree will add tree to the exist Tree +// prefix should has no params +func (t *Tree) AddTree(prefix string, tree *Tree) { + (*web.Tree)(t).AddTree(prefix, (*web.Tree)(tree)) +} + +// AddRouter call addseg function +func (t *Tree) AddRouter(pattern string, runObject interface{}) { + (*web.Tree)(t).AddRouter(pattern, runObject) +} + +// Match router to runObject & params +func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) { + return (*web.Tree)(t).Match(pattern, (*beecontext.Context)(ctx)) +} diff --git a/pkg/adapter/tree_test.go b/pkg/adapter/tree_test.go new file mode 100644 index 00000000..309ed072 --- /dev/null +++ b/pkg/adapter/tree_test.go @@ -0,0 +1,249 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "testing" + + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" +) + +type testinfo struct { + url string + requesturl string + params map[string]string +} + +var routers []testinfo + +func init() { + routers = make([]testinfo, 0) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic", nil}) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"}}) + routers = append(routers, testinfo{"/:id", "/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/hello/?:id", "/hello", map[string]string{":id": ""}}) + routers = append(routers, testinfo{"/", "/", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}}) + routers = append(routers, testinfo{"/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}}) + routers = append(routers, testinfo{"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}}) + routers = append(routers, testinfo{"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}}) + routers = append(routers, testinfo{"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}}) + routers = append(routers, testinfo{"/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"}}) + routers = append(routers, testinfo{"/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}}) + routers = append(routers, testinfo{"/thumbnail/:size/uploads/*", + "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", + map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}}) + routers = append(routers, testinfo{"/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/dl/:width:int/:height:int/*.*", + "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", + map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}}) + routers = append(routers, testinfo{"/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}}) + routers = append(routers, testinfo{"/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}}) +} + +func TestTreeRouters(t *testing.T) { + for _, r := range routers { + tr := NewTree() + tr.AddRouter(r.url, "astaxie") + ctx := context.NewContext() + obj := tr.Match(r.requesturl, ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal(r.url+" can't get obj, Expect ", r.requesturl) + } + if r.params != nil { + for k, v := range r.params { + if vv := ctx.Input.Param(k); vv != v { + t.Fatal("The Rule: " + r.url + "\nThe RequestURL:" + r.requesturl + "\nThe Key is " + k + ", The Value should be: " + v + ", but get: " + vv) + } else if vv == "" && v != "" { + t.Fatal(r.url + " " + r.requesturl + " get param empty:" + k) + } + } + } + } +} + +func TestStaticPath(t *testing.T) { + tr := NewTree() + tr.AddRouter("/topic/:id", "wildcard") + tr.AddRouter("/topic", "static") + ctx := context.NewContext() + obj := tr.Match("/topic", ctx) + if obj == nil || obj.(string) != "static" { + t.Fatal("/topic is a static route") + } + obj = tr.Match("/topic/1", ctx) + if obj == nil || obj.(string) != "wildcard" { + t.Fatal("/topic/1 is a wildcard route") + } +} + +func TestAddTree(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t1 := NewTree() + t1.AddTree("/v1/zl", tr) + ctx := context.NewContext() + obj := t1.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" { + t.Fatal("get :id param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t1.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl//shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" { + t.Fatal("get :sd :id :page param error") + } + + t2 := NewTree() + t2.AddTree("/v1/:shopid", tr) + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t2.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :id :shopid param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t2.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get :shopid param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :sd :id :page :shopid param error") + } +} + +func TestAddTree2(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t3 := NewTree() + t3.AddTree("/:version(v1|v2)/:prefix", tr) + ctx := context.NewContext() + obj := t3.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:version(v1|v2)/:prefix/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":prefix") != "zl" || ctx.Input.Param(":version") != "v1" { + t.Fatal("get :id :prefix :version param error") + } +} + +func TestAddTree3(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/account", "astaxie") + t3 := NewTree() + t3.AddTree("/table/:num", tr) + ctx := context.NewContext() + obj := t3.Match("/table/123/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/shop/:sd/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":num") != "123" || ctx.Input.Param(":sd") != "123" { + t.Fatal("get :num :sd param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t3.Match("/table/123/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/create can't get obj ") + } +} + +func TestAddTree4(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/:account", "astaxie") + t4 := NewTree() + t4.AddTree("/:info:int/:num/:id", tr) + ctx := context.NewContext() + obj := t4.Match("/12/123/456/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/shop/:sd/:account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":info") != "12" || ctx.Input.Param(":num") != "123" || + ctx.Input.Param(":id") != "456" || ctx.Input.Param(":sd") != "123" || + ctx.Input.Param(":account") != "account" { + t.Fatal("get :info :num :id :sd :account param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t4.Match("/12/123/456/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/create can't get obj ") + } +} + +// Test for issue #1595 +func TestAddTree5(t *testing.T) { + tr := NewTree() + tr.AddRouter("/v1/shop/:id", "shopdetail") + tr.AddRouter("/v1/shop/", "shophome") + ctx := context.NewContext() + obj := tr.Match("/v1/shop/", ctx) + if obj == nil || obj.(string) != "shophome" { + t.Fatal("url /v1/shop/ need match router /v1/shop/ ") + } +} diff --git a/pkg/server/web/app.go b/pkg/server/web/app.go index e61084a5..ad3ff663 100644 --- a/pkg/server/web/app.go +++ b/pkg/server/web/app.go @@ -199,7 +199,7 @@ func (app *App) Run(mws ...MiddleWare) { pool.AppendCertsFromPEM(data) app.Server.TLSConfig = &tls.Config{ ClientCAs: pool, - ClientAuth: tls.RequireAndVerifyClientCert, + ClientAuth: tls.ClientAuthType(BConfig.Listen.ClientAuth), } } if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { diff --git a/pkg/server/web/config.go b/pkg/server/web/config.go index bf8db30e..6e69a2fb 100644 --- a/pkg/server/web/config.go +++ b/pkg/server/web/config.go @@ -16,6 +16,7 @@ package web import ( context2 "context" + "crypto/tls" "fmt" "os" "path/filepath" @@ -72,6 +73,7 @@ type Listen struct { AdminPort int EnableFcgi bool EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O + ClientAuth int } // WebConfig holds web related config @@ -234,6 +236,7 @@ func newBConfig() *Config { AdminPort: 8088, EnableFcgi: false, EnableStdIo: false, + ClientAuth: int(tls.RequireAndVerifyClientCert), }, WebConfig: WebConfig{ AutoRender: true, diff --git a/pkg/server/web/filter.go b/pkg/server/web/filter.go index 8d3acb24..e10faafc 100644 --- a/pkg/server/web/filter.go +++ b/pkg/server/web/filter.go @@ -43,24 +43,26 @@ type FilterRouter struct { // params is for: // 1. setting the returnOnOutput value (false allows multiple filters to execute) // 2. determining whether or not params need to be reset. -func newFilterRouter(pattern string, routerCaseSensitive bool, filter FilterFunc, params ...bool) *FilterRouter { +func newFilterRouter(pattern string, filter FilterFunc, opts ...FilterOpt) *FilterRouter { mr := &FilterRouter{ tree: NewTree(), pattern: pattern, filterFunc: filter, returnOnOutput: true, } - if !routerCaseSensitive { + + fos := &filterOpts{} + + for _, o := range opts { + o(fos) + } + + if !fos.routerCaseSensitive { mr.pattern = strings.ToLower(pattern) } - paramsLen := len(params) - if paramsLen > 0 { - mr.returnOnOutput = params[0] - } - if paramsLen > 1 { - mr.resetParams = params[1] - } + mr.returnOnOutput = fos.returnOnOutput + mr.resetParams = fos.resetParams mr.tree.AddRouter(pattern, true) return mr } @@ -103,3 +105,29 @@ func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { } return false } + +type filterOpts struct { + returnOnOutput bool + resetParams bool + routerCaseSensitive bool +} + +type FilterOpt func(opts *filterOpts) + +func WithReturnOnOutput(ret bool) FilterOpt { + return func(opts *filterOpts) { + opts.returnOnOutput = ret + } +} + +func WithResetParams(reset bool) FilterOpt { + return func(opts *filterOpts) { + opts.resetParams = reset + } +} + +func WithCaseSensitive(sensitive bool) FilterOpt { + return func(opts *filterOpts) { + opts.routerCaseSensitive = sensitive + } +} diff --git a/pkg/server/web/router.go b/pkg/server/web/router.go index c3eddd29..3dd19a6f 100644 --- a/pkg/server/web/router.go +++ b/pkg/server/web/router.go @@ -148,7 +148,7 @@ func NewControllerRegister() *ControllerRegister { }, }, } - res.chainRoot = newFilterRouter("/*", false, res.serveHttp) + res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) return res } @@ -262,7 +262,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { if comm, ok := GlobalControllerRouter[key]; ok { for _, a := range comm { for _, f := range a.Filters { - p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams) + p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) } p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) @@ -452,8 +452,9 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) // params is for: // 1. setting the returnOnOutput value (false allows multiple filters to execute) // 2. determining whether or not params need to be reset. -func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { - mr := newFilterRouter(pattern, BConfig.RouterCaseSensitive, filter, params...) +func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) error { + opts = append(opts, WithCaseSensitive(BConfig.RouterCaseSensitive)) + mr := newFilterRouter(pattern, filter, opts...) return p.insertFilterRouter(pos, mr) } @@ -468,10 +469,11 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter // // do something // } // } -func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, params ...bool) { +func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) { root := p.chainRoot filterFunc := chain(root.filterFunc) - p.chainRoot = newFilterRouter(pattern, BConfig.RouterCaseSensitive, filterFunc, params...) + opts = append(opts, WithCaseSensitive(BConfig.RouterCaseSensitive)) + p.chainRoot = newFilterRouter(pattern, filterFunc, opts...) p.chainRoot.next = root } From f1950482c2c0ee8e6e90ad320245f7130ab9cf4e Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 5 Sep 2020 16:54:05 +0800 Subject: [PATCH 24/35] Adapter: plugin --- pkg/adapter/plugins/apiauth/apiauth.go | 94 ++++++++ pkg/adapter/plugins/apiauth/apiauth_test.go | 20 ++ pkg/adapter/plugins/auth/basic.go | 81 +++++++ pkg/adapter/plugins/authz/authz.go | 80 +++++++ pkg/adapter/plugins/authz/authz_model.conf | 14 ++ pkg/adapter/plugins/authz/authz_policy.csv | 7 + pkg/adapter/plugins/authz/authz_test.go | 108 +++++++++ pkg/adapter/plugins/cors/cors.go | 71 ++++++ pkg/adapter/plugins/cors/cors_test.go | 253 ++++++++++++++++++++ pkg/server/web/filter/apiauth/apiauth.go | 5 - 10 files changed, 728 insertions(+), 5 deletions(-) create mode 100644 pkg/adapter/plugins/apiauth/apiauth.go create mode 100644 pkg/adapter/plugins/apiauth/apiauth_test.go create mode 100644 pkg/adapter/plugins/auth/basic.go create mode 100644 pkg/adapter/plugins/authz/authz.go create mode 100644 pkg/adapter/plugins/authz/authz_model.conf create mode 100644 pkg/adapter/plugins/authz/authz_policy.csv create mode 100644 pkg/adapter/plugins/authz/authz_test.go create mode 100644 pkg/adapter/plugins/cors/cors.go create mode 100644 pkg/adapter/plugins/cors/cors_test.go diff --git a/pkg/adapter/plugins/apiauth/apiauth.go b/pkg/adapter/plugins/apiauth/apiauth.go new file mode 100644 index 00000000..ed43f8a0 --- /dev/null +++ b/pkg/adapter/plugins/apiauth/apiauth.go @@ -0,0 +1,94 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package apiauth provides handlers to enable apiauth support. +// +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/apiauth" +// ) +// +// func main(){ +// // apiauth every request +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey")) +// beego.Run() +// } +// +// Advanced Usage: +// +// func getAppSecret(appid string) string { +// // get appsecret by appid +// // maybe store in configure, maybe in database +// } +// +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360)) +// +// Information: +// +// In the request user should include these params in the query +// +// 1. appid +// +// appid is assigned to the application +// +// 2. signature +// +// get the signature use apiauth.Signature() +// +// when you send to server remember use url.QueryEscape() +// +// 3. timestamp: +// +// send the request time, the format is yyyy-mm-dd HH:ii:ss +// +package apiauth + +import ( + "net/url" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/apiauth" +) + +// AppIDToAppSecret is used to get appsecret throw appid +type AppIDToAppSecret apiauth.AppIDToAppSecret + +// APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret +func APIBasicAuth(appid, appkey string) beego.FilterFunc { + f := apiauth.APIBasicAuth(appid, appkey) + return func(c *context.Context) { + f((*beecontext.Context)(c)) + } +} + +// APIBaiscAuth calls APIBasicAuth for previous callers +func APIBaiscAuth(appid, appkey string) beego.FilterFunc { + return APIBasicAuth(appid, appkey) +} + +// APISecretAuth use AppIdToAppSecret verify and +func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { + ft := apiauth.APISecretAuth(apiauth.AppIDToAppSecret(f), timeout) + return func(ctx *context.Context) { + ft((*beecontext.Context)(ctx)) + } +} + +// Signature used to generate signature with the appsecret/method/params/RequestURI +func Signature(appsecret, method string, params url.Values, requestURL string) string { + return apiauth.Signature(appsecret, method, params, requestURL) +} diff --git a/pkg/adapter/plugins/apiauth/apiauth_test.go b/pkg/adapter/plugins/apiauth/apiauth_test.go new file mode 100644 index 00000000..1f56cb0f --- /dev/null +++ b/pkg/adapter/plugins/apiauth/apiauth_test.go @@ -0,0 +1,20 @@ +package apiauth + +import ( + "net/url" + "testing" +) + +func TestSignature(t *testing.T) { + appsecret := "beego secret" + method := "GET" + RequestURL := "http://localhost/test/url" + params := make(url.Values) + params.Add("arg1", "hello") + params.Add("arg2", "beego") + + signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58=" + if Signature(appsecret, method, params, RequestURL) != signature { + t.Error("Signature error") + } +} diff --git a/pkg/adapter/plugins/auth/basic.go b/pkg/adapter/plugins/auth/basic.go new file mode 100644 index 00000000..7a9cd326 --- /dev/null +++ b/pkg/adapter/plugins/auth/basic.go @@ -0,0 +1,81 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package auth provides handlers to enable basic auth support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/auth" +// ) +// +// func main(){ +// // authenticate every request +// beego.InsertFilter("*", beego.BeforeRouter,auth.Basic("username","secretpassword")) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func SecretAuth(username, password string) bool { +// return username == "astaxie" && password == "helloBeego" +// } +// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "Authorization Required") +// beego.InsertFilter("*", beego.BeforeRouter,authPlugin) +package auth + +import ( + "net/http" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/auth" +) + +// Basic is the http basic auth +func Basic(username string, password string) beego.FilterFunc { + return func(c *context.Context) { + f := auth.Basic(username, password) + f((*beecontext.Context)(c)) + } +} + +// NewBasicAuthenticator return the BasicAuth +func NewBasicAuthenticator(secrets SecretProvider, realm string) beego.FilterFunc { + f := auth.NewBasicAuthenticator(auth.SecretProvider(secrets), realm) + return func(c *context.Context) { + f((*beecontext.Context)(c)) + } +} + +// SecretProvider is the SecretProvider function +type SecretProvider auth.SecretProvider + +// BasicAuth store the SecretProvider and Realm +type BasicAuth auth.BasicAuth + +// CheckAuth Checks the username/password combination from the request. Returns +// either an empty string (authentication failed) or the name of the +// authenticated user. +// Supports MD5 and SHA1 password entries +func (a *BasicAuth) CheckAuth(r *http.Request) string { + return (*auth.BasicAuth)(a).CheckAuth(r) +} + +// RequireAuth http.Handler for BasicAuth which initiates the authentication process +// (or requires reauthentication). +func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) { + (*auth.BasicAuth)(a).RequireAuth(w, r) +} diff --git a/pkg/adapter/plugins/authz/authz.go b/pkg/adapter/plugins/authz/authz.go new file mode 100644 index 00000000..c38be9cb --- /dev/null +++ b/pkg/adapter/plugins/authz/authz.go @@ -0,0 +1,80 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/authz" +// "github.com/casbin/casbin" +// ) +// +// func main(){ +// // mediate the access for every request +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func main(){ +// e := casbin.NewEnforcer("authz_model.conf", "") +// e.AddRoleForUser("alice", "admin") +// e.AddPolicy(...) +// +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e)) +// beego.Run() +// } +package authz + +import ( + "net/http" + + "github.com/casbin/casbin" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/authz" +) + +// NewAuthorizer returns the authorizer. +// Use a casbin enforcer as input +func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc { + f := authz.NewAuthorizer(e) + return func(context *context.Context) { + f((*beecontext.Context)(context)) + } +} + +// BasicAuthorizer stores the casbin handler +type BasicAuthorizer authz.BasicAuthorizer + +// GetUserName gets the user name from the request. +// Currently, only HTTP basic authentication is supported +func (a *BasicAuthorizer) GetUserName(r *http.Request) string { + return (*authz.BasicAuthorizer)(a).GetUserName(r) +} + +// CheckPermission checks the user/method/path combination from the request. +// Returns true (permission granted) or false (permission forbidden) +func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool { + return (*authz.BasicAuthorizer)(a).CheckPermission(r) +} + +// RequirePermission returns the 403 Forbidden to the client +func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) { + (*authz.BasicAuthorizer)(a).RequirePermission(w) +} diff --git a/pkg/adapter/plugins/authz/authz_model.conf b/pkg/adapter/plugins/authz/authz_model.conf new file mode 100644 index 00000000..d1b3dbd7 --- /dev/null +++ b/pkg/adapter/plugins/authz/authz_model.conf @@ -0,0 +1,14 @@ +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") \ No newline at end of file diff --git a/pkg/adapter/plugins/authz/authz_policy.csv b/pkg/adapter/plugins/authz/authz_policy.csv new file mode 100644 index 00000000..c062dd3e --- /dev/null +++ b/pkg/adapter/plugins/authz/authz_policy.csv @@ -0,0 +1,7 @@ +p, alice, /dataset1/*, GET +p, alice, /dataset1/resource1, POST +p, bob, /dataset2/resource1, * +p, bob, /dataset2/resource2, GET +p, bob, /dataset2/folder1/*, POST +p, dataset1_admin, /dataset1/*, * +g, cathy, dataset1_admin \ No newline at end of file diff --git a/pkg/adapter/plugins/authz/authz_test.go b/pkg/adapter/plugins/authz/authz_test.go new file mode 100644 index 00000000..ddbda5f4 --- /dev/null +++ b/pkg/adapter/plugins/authz/authz_test.go @@ -0,0 +1,108 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authz + +import ( + "net/http" + "net/http/httptest" + "testing" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/adapter/plugins/auth" + "github.com/casbin/casbin" +) + +func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + r.SetBasicAuth(user, "123") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code) + } +} + +func TestBasic(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403) +} + +func TestPathWildcard(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403) + testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403) + + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403) +} + +func TestRBAC(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) + e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role. + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) + + // delete all roles on user cathy, so cathy cannot access any resources now. + e.DeleteRolesForUser("cathy") + + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) +} diff --git a/pkg/adapter/plugins/cors/cors.go b/pkg/adapter/plugins/cors/cors.go new file mode 100644 index 00000000..65af8b8f --- /dev/null +++ b/pkg/adapter/plugins/cors/cors.go @@ -0,0 +1,71 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cors provides handlers to enable CORS support. +// Usage +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/cors" +// ) +// +// func main() { +// // CORS for https://foo.* origins, allowing: +// // - PUT and PATCH methods +// // - Origin header +// // - Credentials share +// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{ +// AllowOrigins: []string{"https://*.foo.com"}, +// AllowMethods: []string{"PUT", "PATCH"}, +// AllowHeaders: []string{"Origin"}, +// ExposeHeaders: []string{"Content-Length"}, +// AllowCredentials: true, +// })) +// beego.Run() +// } +package cors + +import ( + beego "github.com/astaxie/beego/pkg/adapter" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/cors" + + "github.com/astaxie/beego/pkg/adapter/context" +) + +// Options represents Access Control options. +type Options cors.Options + +// Header converts options into CORS headers. +func (o *Options) Header(origin string) (headers map[string]string) { + return (*cors.Options)(o).Header(origin) +} + +// PreflightHeader converts options into CORS headers for a preflight response. +func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) { + return (*cors.Options)(o).PreflightHeader(origin, rMethod, rHeaders) +} + +// IsOriginAllowed looks up if the origin matches one of the patterns +// generated from Options.AllowOrigins patterns. +func (o *Options) IsOriginAllowed(origin string) bool { + return (*cors.Options)(o).IsOriginAllowed(origin) +} + +// Allow enables CORS for requests those match the provided options. +func Allow(opts *Options) beego.FilterFunc { + f := cors.Allow((*cors.Options)(opts)) + return func(c *context.Context) { + f((*beecontext.Context)(c)) + } +} diff --git a/pkg/adapter/plugins/cors/cors_test.go b/pkg/adapter/plugins/cors/cors_test.go new file mode 100644 index 00000000..34039143 --- /dev/null +++ b/pkg/adapter/plugins/cors/cors_test.go @@ -0,0 +1,253 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cors + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header +type HTTPHeaderGuardRecorder struct { + *httptest.ResponseRecorder + savedHeaderMap http.Header +} + +// NewRecorder return HttpHeaderGuardRecorder +func NewRecorder() *HTTPHeaderGuardRecorder { + return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil} +} + +func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) { + gr.ResponseRecorder.WriteHeader(code) + gr.savedHeaderMap = gr.ResponseRecorder.Header() +} + +func (gr *HTTPHeaderGuardRecorder) Header() http.Header { + if gr.savedHeaderMap != nil { + // headers were written. clone so we don't get updates + clone := make(http.Header) + for k, v := range gr.savedHeaderMap { + clone[k] = v + } + return clone + } + return gr.ResponseRecorder.Header() +} + +func Test_AllowAll(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + if recorder.HeaderMap.Get(headerAllowOrigin) != "*" { + t.Errorf("Allow-Origin header should be *") + } +} + +func Test_AllowRegexMatch(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + origin := "https://bar.foo.com" + r, _ := http.NewRequest("PUT", "/foo", nil) + r.Header.Add("Origin", origin) + handler.ServeHTTP(recorder, r) + + headerValue := recorder.HeaderMap.Get(headerAllowOrigin) + if headerValue != origin { + t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue) + } +} + +func Test_AllowRegexNoMatch(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowOrigins: []string{"https://*.foo.com"}, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + origin := "https://ww.foo.com.evil.com" + r, _ := http.NewRequest("PUT", "/foo", nil) + r.Header.Add("Origin", origin) + handler.ServeHTTP(recorder, r) + + headerValue := recorder.HeaderMap.Get(headerAllowOrigin) + if headerValue != "" { + t.Errorf("Allow-Origin header should not exist, found %v", headerValue) + } +} + +func Test_OtherHeaders(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowCredentials: true, + AllowMethods: []string{"PATCH", "GET"}, + AllowHeaders: []string{"Origin", "X-whatever"}, + ExposeHeaders: []string{"Content-Length", "Hello"}, + MaxAge: 5 * time.Minute, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials) + methodsVal := recorder.HeaderMap.Get(headerAllowMethods) + headersVal := recorder.HeaderMap.Get(headerAllowHeaders) + exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders) + maxAgeVal := recorder.HeaderMap.Get(headerMaxAge) + + if credentialsVal != "true" { + t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal) + } + + if methodsVal != "PATCH,GET" { + t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal) + } + + if headersVal != "Origin,X-whatever" { + t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal) + } + + if exposedHeadersVal != "Content-Length,Hello" { + t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal) + } + + if maxAgeVal != "300" { + t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal) + } +} + +func Test_DefaultAllowHeaders(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + headersVal := recorder.HeaderMap.Get(headerAllowHeaders) + if headersVal != "Origin,Accept,Content-Type,Authorization" { + t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal) + } +} + +func Test_Preflight(t *testing.T) { + recorder := NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowMethods: []string{"PUT", "PATCH"}, + AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, + })) + + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + r, _ := http.NewRequest("OPTIONS", "/foo", nil) + r.Header.Add(headerRequestMethod, "PUT") + r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive") + handler.ServeHTTP(recorder, r) + + headers := recorder.Header() + methodsVal := headers.Get(headerAllowMethods) + headersVal := headers.Get(headerAllowHeaders) + originVal := headers.Get(headerAllowOrigin) + + if methodsVal != "PUT,PATCH" { + t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal) + } + + if !strings.Contains(headersVal, "X-whatever") { + t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal) + } + + if !strings.Contains(headersVal, "x-casesensitive") { + t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal) + } + + if originVal != "*" { + t.Errorf("Allow-Origin is expected to be *, found %v", originVal) + } + + if recorder.Code != http.StatusOK { + t.Errorf("Status code is expected to be 200, found %d", recorder.Code) + } +} + +func Benchmark_WithoutCORS(b *testing.B) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + beego.BConfig.RunMode = beego.PROD + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + for i := 0; i < b.N; i++ { + handler.ServeHTTP(recorder, r) + } +} + +func Benchmark_WithCORS(b *testing.B) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + beego.BConfig.RunMode = beego.PROD + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowCredentials: true, + AllowMethods: []string{"PATCH", "GET"}, + AllowHeaders: []string{"Origin", "X-whatever"}, + MaxAge: 5 * time.Minute, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + for i := 0; i < b.N; i++ { + handler.ServeHTTP(recorder, r) + } +} diff --git a/pkg/server/web/filter/apiauth/apiauth.go b/pkg/server/web/filter/apiauth/apiauth.go index ba56030b..8944db63 100644 --- a/pkg/server/web/filter/apiauth/apiauth.go +++ b/pkg/server/web/filter/apiauth/apiauth.go @@ -83,11 +83,6 @@ func APIBasicAuth(appid, appkey string) web.FilterFunc { return APISecretAuth(ft, 300) } -// APIBasicAuth calls APIBasicAuth for previous callers -func APIBaiscAuth(appid, appkey string) web.FilterFunc { - return APIBasicAuth(appid, appkey) -} - // APISecretAuth uses AppIdToAppSecret verify and func APISecretAuth(f AppIDToAppSecret, timeout int) web.FilterFunc { return func(ctx *context.Context) { From f6c95ad5346e77ebf0ade03489e3080d62a76e0f Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 5 Sep 2020 16:56:56 +0800 Subject: [PATCH 25/35] Adapter: swagger module --- pkg/adapter/swagger/swagger.go | 68 ++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 pkg/adapter/swagger/swagger.go diff --git a/pkg/adapter/swagger/swagger.go b/pkg/adapter/swagger/swagger.go new file mode 100644 index 00000000..214959d9 --- /dev/null +++ b/pkg/adapter/swagger/swagger.go @@ -0,0 +1,68 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Swagger™ is a project used to describe and document RESTful APIs. +// +// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools. +// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software. + +// Package swagger struct definition +package swagger + +import ( + "github.com/astaxie/beego/pkg/server/web/swagger" +) + +// Swagger list the resource +type Swagger swagger.Swagger + +// Information Provides metadata about the API. The metadata can be used by the clients if needed. +type Information swagger.Information + +// Contact information for the exposed API. +type Contact swagger.Contact + +// License information for the exposed API. +type License swagger.License + +// Item Describes the operations available on a single path. +type Item swagger.Item + +// Operation Describes a single API operation on a path. +type Operation swagger.Operation + +// Parameter Describes a single operation parameter. +type Parameter swagger.Parameter + +// ParameterItems A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body". +// http://swagger.io/specification/#itemsObject +type ParameterItems swagger.ParameterItems + +// Schema Object allows the definition of input and output data types. +type Schema swagger.Schema + +// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification +type Propertie swagger.Propertie + +// Response as they are returned from executing this operation. +type Response swagger.Response + +// Security Allows the definition of a security scheme that can be used by the operations +type Security swagger.Security + +// Tag Allows adding meta data to a single tag that is used by the Operation Object +type Tag swagger.Tag + +// ExternalDocs include Additional external documentation +type ExternalDocs swagger.ExternalDocs From 35f1bd211929cb32e9dccdc82420782d25a2804f Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 5 Sep 2020 16:58:49 +0800 Subject: [PATCH 26/35] Adapter: testing --- pkg/adapter/testing/client.go | 50 +++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 pkg/adapter/testing/client.go diff --git a/pkg/adapter/testing/client.go b/pkg/adapter/testing/client.go new file mode 100644 index 00000000..688aa6f3 --- /dev/null +++ b/pkg/adapter/testing/client.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing + +import ( + "github.com/astaxie/beego/pkg/client/httplib/testing" +) + +var port = "" +var baseURL = "http://localhost:" + +// TestHTTPRequest beego test request client +type TestHTTPRequest testing.TestHTTPRequest + +// Get returns test client in GET method +func Get(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Get(path)) +} + +// Post returns test client in POST method +func Post(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Post(path)) +} + +// Put returns test client in PUT method +func Put(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Put(path)) +} + +// Delete returns test client in DELETE method +func Delete(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Delete(path)) +} + +// Head returns test client in HEAD method +func Head(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Head(path)) +} From f4a43814bec6e005d92de5a91aaaa513482e0f9d Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 5 Sep 2020 18:07:42 +0800 Subject: [PATCH 27/35] Adapter: utils --- pkg/adapter/cache/cache.go | 85 +++++++++ pkg/adapter/cache/cache_test.go | 191 +++++++++++++++++++++ pkg/adapter/utils/caller.go | 24 +++ pkg/adapter/utils/caller_test.go | 28 +++ pkg/adapter/utils/captcha/LICENSE | 19 ++ pkg/adapter/utils/captcha/README.md | 45 +++++ pkg/adapter/utils/captcha/captcha.go | 124 +++++++++++++ pkg/adapter/utils/captcha/image.go | 35 ++++ pkg/adapter/utils/captcha/image_test.go | 58 +++++++ pkg/adapter/utils/debug.go | 34 ++++ pkg/adapter/utils/debug_test.go | 46 +++++ pkg/adapter/utils/file.go | 47 +++++ pkg/adapter/utils/file_test.go | 75 ++++++++ pkg/adapter/utils/mail.go | 63 +++++++ pkg/adapter/utils/mail_test.go | 41 +++++ pkg/adapter/utils/pagination/controller.go | 26 +++ pkg/adapter/utils/pagination/doc.go | 58 +++++++ pkg/adapter/utils/pagination/paginator.go | 112 ++++++++++++ pkg/adapter/utils/rand.go | 24 +++ pkg/adapter/utils/rand_test.go | 33 ++++ pkg/adapter/utils/safemap.go | 58 +++++++ pkg/adapter/utils/safemap_test.go | 89 ++++++++++ pkg/adapter/utils/slice.go | 101 +++++++++++ pkg/adapter/utils/slice_test.go | 29 ++++ pkg/adapter/utils/utils.go | 10 ++ 25 files changed, 1455 insertions(+) create mode 100644 pkg/adapter/cache/cache.go create mode 100644 pkg/adapter/cache/cache_test.go create mode 100644 pkg/adapter/utils/caller.go create mode 100644 pkg/adapter/utils/caller_test.go create mode 100644 pkg/adapter/utils/captcha/LICENSE create mode 100644 pkg/adapter/utils/captcha/README.md create mode 100644 pkg/adapter/utils/captcha/captcha.go create mode 100644 pkg/adapter/utils/captcha/image.go create mode 100644 pkg/adapter/utils/captcha/image_test.go create mode 100644 pkg/adapter/utils/debug.go create mode 100644 pkg/adapter/utils/debug_test.go create mode 100644 pkg/adapter/utils/file.go create mode 100644 pkg/adapter/utils/file_test.go create mode 100644 pkg/adapter/utils/mail.go create mode 100644 pkg/adapter/utils/mail_test.go create mode 100644 pkg/adapter/utils/pagination/controller.go create mode 100644 pkg/adapter/utils/pagination/doc.go create mode 100644 pkg/adapter/utils/pagination/paginator.go create mode 100644 pkg/adapter/utils/rand.go create mode 100644 pkg/adapter/utils/rand_test.go create mode 100644 pkg/adapter/utils/safemap.go create mode 100644 pkg/adapter/utils/safemap_test.go create mode 100644 pkg/adapter/utils/slice.go create mode 100644 pkg/adapter/utils/slice_test.go create mode 100644 pkg/adapter/utils/utils.go diff --git a/pkg/adapter/cache/cache.go b/pkg/adapter/cache/cache.go new file mode 100644 index 00000000..21bb9141 --- /dev/null +++ b/pkg/adapter/cache/cache.go @@ -0,0 +1,85 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cache provide a Cache interface and some implement engine +// Usage: +// +// import( +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("memory", `{"interval":60}`) +// +// Use it like this: +// +// bm.Put("astaxie", 1, 10 * time.Second) +// bm.Get("astaxie") +// bm.IsExist("astaxie") +// bm.Delete("astaxie") +// +// more docs http://beego.me/docs/module/cache.md +package cache + +import ( + "fmt" + + "github.com/astaxie/beego/pkg/client/cache" +) + +// Cache interface contains all behaviors for cache adapter. +// usage: +// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go. +// c,err := cache.NewCache("file","{....}") +// c.Put("key",value, 3600 * time.Second) +// v := c.Get("key") +// +// c.Incr("counter") // now is 1 +// c.Incr("counter") // now is 2 +// count := c.Get("counter").(int) +type Cache cache.Cache + +// Instance is a function create a new Cache Instance +type Instance func() Cache + +var adapters = make(map[string]Instance) + +// Register makes a cache adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Instance) { + if adapter == nil { + panic("cache: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("cache: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewCache Create a new cache driver by adapter name and config string. +// config need to be correct JSON as string: {"interval":360}. +// it will start gc automatically. +func NewCache(adapterName, config string) (adapter Cache, err error) { + instanceFunc, ok := adapters[adapterName] + if !ok { + err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + return + } + adapter = instanceFunc() + err = adapter.StartAndGC(config) + if err != nil { + adapter = nil + } + return +} diff --git a/pkg/adapter/cache/cache_test.go b/pkg/adapter/cache/cache_test.go new file mode 100644 index 00000000..470c0a43 --- /dev/null +++ b/pkg/adapter/cache/cache_test.go @@ -0,0 +1,191 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "os" + "sync" + "testing" + "time" +) + +func TestCacheIncr(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + //timeoutDuration := 10 * time.Second + + bm.Put("edwardhey", 0, time.Second*20) + wg := sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + bm.Incr("edwardhey") + }() + } + wg.Wait() + if bm.Get("edwardhey").(int) != 10 { + t.Error("Incr err") + } +} + +func TestCache(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + time.Sleep(30 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test GetMulti + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } +} + +func TestFileCache(t *testing.T) { + bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } + + os.RemoveAll("cache") +} diff --git a/pkg/adapter/utils/caller.go b/pkg/adapter/utils/caller.go new file mode 100644 index 00000000..d4fcc456 --- /dev/null +++ b/pkg/adapter/utils/caller.go @@ -0,0 +1,24 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// GetFuncName get function name +func GetFuncName(i interface{}) string { + return utils.GetFuncName(i) +} diff --git a/pkg/adapter/utils/caller_test.go b/pkg/adapter/utils/caller_test.go new file mode 100644 index 00000000..0675f0aa --- /dev/null +++ b/pkg/adapter/utils/caller_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "strings" + "testing" +) + +func TestGetFuncName(t *testing.T) { + name := GetFuncName(TestGetFuncName) + t.Log(name) + if !strings.HasSuffix(name, ".TestGetFuncName") { + t.Error("get func name error") + } +} diff --git a/pkg/adapter/utils/captcha/LICENSE b/pkg/adapter/utils/captcha/LICENSE new file mode 100644 index 00000000..0ad73ae0 --- /dev/null +++ b/pkg/adapter/utils/captcha/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2011-2014 Dmitry Chestnykh + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/pkg/adapter/utils/captcha/README.md b/pkg/adapter/utils/captcha/README.md new file mode 100644 index 00000000..dbc2026b --- /dev/null +++ b/pkg/adapter/utils/captcha/README.md @@ -0,0 +1,45 @@ +# Captcha + +an example for use captcha + +``` +package controllers + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/utils/captcha" +) + +var cpt *captcha.Captcha + +func init() { + // use beego cache system store the captcha data + store := cache.NewMemoryCache() + cpt = captcha.NewWithFilter("/captcha/", store) +} + +type MainController struct { + beego.Controller +} + +func (this *MainController) Get() { + this.TplName = "index.tpl" +} + +func (this *MainController) Post() { + this.TplName = "index.tpl" + + this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +} +``` + +template usage + +``` +{{.Success}} +
+ {{create_captcha}} + +
+``` diff --git a/pkg/adapter/utils/captcha/captcha.go b/pkg/adapter/utils/captcha/captcha.go new file mode 100644 index 00000000..faadc8bf --- /dev/null +++ b/pkg/adapter/utils/captcha/captcha.go @@ -0,0 +1,124 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package captcha implements generation and verification of image CAPTCHAs. +// an example for use captcha +// +// ``` +// package controllers +// +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/cache" +// "github.com/astaxie/beego/utils/captcha" +// ) +// +// var cpt *captcha.Captcha +// +// func init() { +// // use beego cache system store the captcha data +// store := cache.NewMemoryCache() +// cpt = captcha.NewWithFilter("/captcha/", store) +// } +// +// type MainController struct { +// beego.Controller +// } +// +// func (this *MainController) Get() { +// this.TplName = "index.tpl" +// } +// +// func (this *MainController) Post() { +// this.TplName = "index.tpl" +// +// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +// } +// ``` +// +// template usage +// +// ``` +// {{.Success}} +//
+// {{create_captcha}} +// +//
+// ``` +package captcha + +import ( + "html/template" + "net/http" + "time" + + "github.com/astaxie/beego/pkg/server/web/captcha" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/adapter/cache" + "github.com/astaxie/beego/pkg/adapter/context" +) + +var ( + defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +) + +const ( + // default captcha attributes + challengeNums = 6 + expiration = 600 * time.Second + fieldIDName = "captcha_id" + fieldCaptchaName = "captcha" + cachePrefix = "captcha_" + defaultURLPrefix = "/captcha/" +) + +// Captcha struct +type Captcha captcha.Captcha + +// Handler beego filter handler for serve captcha image +func (c *Captcha) Handler(ctx *context.Context) { + (*captcha.Captcha)(c).Handler((*beecontext.Context)(ctx)) +} + +// CreateCaptchaHTML template func for output html +func (c *Captcha) CreateCaptchaHTML() template.HTML { + return (*captcha.Captcha)(c).CreateCaptchaHTML() +} + +// CreateCaptcha create a new captcha id +func (c *Captcha) CreateCaptcha() (string, error) { + return (*captcha.Captcha)(c).CreateCaptcha() +} + +// VerifyReq verify from a request +func (c *Captcha) VerifyReq(req *http.Request) bool { + return (*captcha.Captcha)(c).VerifyReq(req) +} + +// Verify direct verify id and challenge string +func (c *Captcha) Verify(id string, challenge string) (success bool) { + return (*captcha.Captcha)(c).Verify(id, challenge) +} + +// NewCaptcha create a new captcha.Captcha +func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { + return (*Captcha)(captcha.NewCaptcha(urlPrefix, store)) +} + +// NewWithFilter create a new captcha.Captcha and auto AddFilter for serve captacha image +// and add a template func for output html +func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { + return (*Captcha)(captcha.NewWithFilter(urlPrefix, store)) +} diff --git a/pkg/adapter/utils/captcha/image.go b/pkg/adapter/utils/captcha/image.go new file mode 100644 index 00000000..9979db84 --- /dev/null +++ b/pkg/adapter/utils/captcha/image.go @@ -0,0 +1,35 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "io" + + "github.com/astaxie/beego/pkg/server/web/captcha" +) + +// Image struct +type Image captcha.Image + +// NewImage returns a new captcha image of the given width and height with the +// given digits, where each digit must be in range 0-9. +func NewImage(digits []byte, width, height int) *Image { + return (*Image)(captcha.NewImage(digits, width, height)) +} + +// WriteTo writes captcha image in PNG format into the given writer. +func (m *Image) WriteTo(w io.Writer) (int64, error) { + return (*captcha.Image)(m).WriteTo(w) +} diff --git a/pkg/adapter/utils/captcha/image_test.go b/pkg/adapter/utils/captcha/image_test.go new file mode 100644 index 00000000..bce2134a --- /dev/null +++ b/pkg/adapter/utils/captcha/image_test.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "testing" + + "github.com/astaxie/beego/pkg/adapter/utils" +) + +const ( + // Standard width and height of a captcha image. + stdWidth = 240 + stdHeight = 80 +) + +type byteCounter struct { + n int64 +} + +func (bc *byteCounter) Write(b []byte) (int, error) { + bc.n += int64(len(b)) + return len(b), nil +} + +func BenchmarkNewImage(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + for i := 0; i < b.N; i++ { + NewImage(d, stdWidth, stdHeight) + } +} + +func BenchmarkImageWriteTo(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + counter := &byteCounter{} + for i := 0; i < b.N; i++ { + img := NewImage(d, stdWidth, stdHeight) + img.WriteTo(counter) + b.SetBytes(counter.n) + counter.n = 0 + } +} diff --git a/pkg/adapter/utils/debug.go b/pkg/adapter/utils/debug.go new file mode 100644 index 00000000..d39f3d3e --- /dev/null +++ b/pkg/adapter/utils/debug.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// Display print the data in console +func Display(data ...interface{}) { + utils.Display(data...) +} + +// GetDisplayString return data print string +func GetDisplayString(data ...interface{}) string { + return utils.GetDisplayString(data...) +} + +// Stack get stack bytes +func Stack(skip int, indent string) []byte { + return utils.Stack(skip, indent) +} diff --git a/pkg/adapter/utils/debug_test.go b/pkg/adapter/utils/debug_test.go new file mode 100644 index 00000000..efb8924e --- /dev/null +++ b/pkg/adapter/utils/debug_test.go @@ -0,0 +1,46 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +type mytype struct { + next *mytype + prev *mytype +} + +func TestPrint(t *testing.T) { + Display("v1", 1, "v2", 2, "v3", 3) +} + +func TestPrintPoint(t *testing.T) { + var v1 = new(mytype) + var v2 = new(mytype) + + v1.prev = nil + v1.next = v2 + + v2.prev = v1 + v2.next = nil + + Display("v1", v1, "v2", v2) +} + +func TestPrintString(t *testing.T) { + str := GetDisplayString("v1", 1, "v2", 2) + println(str) +} diff --git a/pkg/adapter/utils/file.go b/pkg/adapter/utils/file.go new file mode 100644 index 00000000..8979389e --- /dev/null +++ b/pkg/adapter/utils/file.go @@ -0,0 +1,47 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// SelfPath gets compiled executable file absolute path +func SelfPath() string { + return utils.SelfPath() +} + +// SelfDir gets compiled executable file directory +func SelfDir() string { + return utils.SelfDir() +} + +// FileExists reports whether the named file or directory exists. +func FileExists(name string) bool { + return utils.FileExists(name) +} + +// SearchFile Search a file in paths. +// this is often used in search config file in /etc ~/ +func SearchFile(filename string, paths ...string) (fullpath string, err error) { + return utils.SearchFile(filename, paths...) +} + +// GrepFile like command grep -E +// for example: GrepFile(`^hello`, "hello.txt") +// \n is striped while read +func GrepFile(patten string, filename string) (lines []string, err error) { + return utils.GrepFile(patten, filename) +} diff --git a/pkg/adapter/utils/file_test.go b/pkg/adapter/utils/file_test.go new file mode 100644 index 00000000..b2644157 --- /dev/null +++ b/pkg/adapter/utils/file_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "path/filepath" + "reflect" + "testing" +) + +var noExistedFile = "/tmp/not_existed_file" + +func TestSelfPath(t *testing.T) { + path := SelfPath() + if path == "" { + t.Error("path cannot be empty") + } + t.Logf("SelfPath: %s", path) +} + +func TestSelfDir(t *testing.T) { + dir := SelfDir() + t.Logf("SelfDir: %s", dir) +} + +func TestFileExists(t *testing.T) { + if !FileExists("./file.go") { + t.Errorf("./file.go should exists, but it didn't") + } + + if FileExists(noExistedFile) { + t.Errorf("Weird, how could this file exists: %s", noExistedFile) + } +} + +func TestSearchFile(t *testing.T) { + path, err := SearchFile(filepath.Base(SelfPath()), SelfDir()) + if err != nil { + t.Error(err) + } + t.Log(path) + + _, err = SearchFile(noExistedFile, ".") + if err == nil { + t.Errorf("err shouldnt be nil, got path: %s", SelfDir()) + } +} + +func TestGrepFile(t *testing.T) { + _, err := GrepFile("", noExistedFile) + if err == nil { + t.Error("expect file-not-existed error, but got nothing") + } + + path := filepath.Join(".", "testdata", "grepe.test") + lines, err := GrepFile(`^\s*[^#]+`, path) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(lines, []string{"hello", "world"}) { + t.Errorf("expect [hello world], but receive %v", lines) + } +} diff --git a/pkg/adapter/utils/mail.go b/pkg/adapter/utils/mail.go new file mode 100644 index 00000000..35a58756 --- /dev/null +++ b/pkg/adapter/utils/mail.go @@ -0,0 +1,63 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "io" + + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// Email is the type used for email messages +type Email utils.Email + +// Attachment is a struct representing an email attachment. +// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question +type Attachment utils.Attachment + +// NewEMail create new Email struct with config json. +// config json is followed from Email struct fields. +func NewEMail(config string) *Email { + return (*Email)(utils.NewEMail(config)) +} + +// Bytes Make all send information to byte +func (e *Email) Bytes() ([]byte, error) { + return (*utils.Email)(e).Bytes() +} + +// AttachFile Add attach file to the send mail +func (e *Email) AttachFile(args ...string) (*Attachment, error) { + a, err := (*utils.Email)(e).AttachFile(args...) + if err != nil { + return nil, err + } + return (*Attachment)(a), err +} + +// Attach is used to attach content from an io.Reader to the email. +// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type. +func (e *Email) Attach(r io.Reader, filename string, args ...string) (*Attachment, error) { + a, err := (*utils.Email)(e).Attach(r, filename, args...) + if err != nil { + return nil, err + } + return (*Attachment)(a), err +} + +// Send will send out the mail +func (e *Email) Send() error { + return (*utils.Email)(e).Send() +} diff --git a/pkg/adapter/utils/mail_test.go b/pkg/adapter/utils/mail_test.go new file mode 100644 index 00000000..c38356a2 --- /dev/null +++ b/pkg/adapter/utils/mail_test.go @@ -0,0 +1,41 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestMail(t *testing.T) { + config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}` + mail := NewEMail(config) + if mail.Username != "astaxie@gmail.com" { + t.Fatal("email parse get username error") + } + if mail.Password != "astaxie" { + t.Fatal("email parse get password error") + } + if mail.Host != "smtp.gmail.com" { + t.Fatal("email parse get host error") + } + if mail.Port != 587 { + t.Fatal("email parse get port error") + } + mail.To = []string{"xiemengjun@gmail.com"} + mail.From = "astaxie@gmail.com" + mail.Subject = "hi, just from beego!" + mail.Text = "Text Body is, of course, supported!" + mail.HTML = "

Fancy Html is supported, too!

" + mail.AttachFile("/Users/astaxie/github/beego/beego.go") + mail.Send() +} diff --git a/pkg/adapter/utils/pagination/controller.go b/pkg/adapter/utils/pagination/controller.go new file mode 100644 index 00000000..a908d8b0 --- /dev/null +++ b/pkg/adapter/utils/pagination/controller.go @@ -0,0 +1,26 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/pagination" +) + +// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). +func SetPaginator(ctx *context.Context, per int, nums int64) (paginator *Paginator) { + return (*Paginator)(pagination.SetPaginator((*beecontext.Context)(ctx), per, nums)) +} diff --git a/pkg/adapter/utils/pagination/doc.go b/pkg/adapter/utils/pagination/doc.go new file mode 100644 index 00000000..9abc6d78 --- /dev/null +++ b/pkg/adapter/utils/pagination/doc.go @@ -0,0 +1,58 @@ +/* +Package pagination provides utilities to setup a paginator within the +context of a http request. + +Usage + +In your beego.Controller: + + package controllers + + import "github.com/astaxie/beego/utils/pagination" + + type PostsController struct { + beego.Controller + } + + func (this *PostsController) ListAllPosts() { + // sets this.Data["paginator"] with the current offset (from the url query param) + postsPerPage := 20 + paginator := pagination.SetPaginator(this.Ctx, postsPerPage, CountPosts()) + + // fetch the next 20 posts + this.Data["posts"] = ListPostsByOffsetAndLimit(paginator.Offset(), postsPerPage) + } + + +In your view templates: + + {{if .paginator.HasPages}} + + {{end}} + +See also + +http://beego.me/docs/mvc/view/page.md + +*/ +package pagination diff --git a/pkg/adapter/utils/pagination/paginator.go b/pkg/adapter/utils/pagination/paginator.go new file mode 100644 index 00000000..4bd4a1b0 --- /dev/null +++ b/pkg/adapter/utils/pagination/paginator.go @@ -0,0 +1,112 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/utils/pagination" +) + +// Paginator within the state of a http request. +type Paginator pagination.Paginator + +// PageNums Returns the total number of pages. +func (p *Paginator) PageNums() int { + return (*pagination.Paginator)(p).PageNums() +} + +// Nums Returns the total number of items (e.g. from doing SQL count). +func (p *Paginator) Nums() int64 { + return (*pagination.Paginator)(p).Nums() +} + +// SetNums Sets the total number of items. +func (p *Paginator) SetNums(nums interface{}) { + (*pagination.Paginator)(p).SetNums(nums) +} + +// Page Returns the current page. +func (p *Paginator) Page() int { + return (*pagination.Paginator)(p).Page() +} + +// Pages Returns a list of all pages. +// +// Usage (in a view template): +// +// {{range $index, $page := .paginator.Pages}} +// +// {{$page}} +// +// {{end}} +func (p *Paginator) Pages() []int { + return (*pagination.Paginator)(p).Pages() +} + +// PageLink Returns URL for a given page index. +func (p *Paginator) PageLink(page int) string { + return (*pagination.Paginator)(p).PageLink(page) +} + +// PageLinkPrev Returns URL to the previous page. +func (p *Paginator) PageLinkPrev() (link string) { + return (*pagination.Paginator)(p).PageLinkPrev() +} + +// PageLinkNext Returns URL to the next page. +func (p *Paginator) PageLinkNext() (link string) { + return (*pagination.Paginator)(p).PageLinkNext() +} + +// PageLinkFirst Returns URL to the first page. +func (p *Paginator) PageLinkFirst() (link string) { + return (*pagination.Paginator)(p).PageLinkFirst() +} + +// PageLinkLast Returns URL to the last page. +func (p *Paginator) PageLinkLast() (link string) { + return (*pagination.Paginator)(p).PageLinkLast() +} + +// HasPrev Returns true if the current page has a predecessor. +func (p *Paginator) HasPrev() bool { + return (*pagination.Paginator)(p).HasPrev() +} + +// HasNext Returns true if the current page has a successor. +func (p *Paginator) HasNext() bool { + return (*pagination.Paginator)(p).HasNext() +} + +// IsActive Returns true if the given page index points to the current page. +func (p *Paginator) IsActive(page int) bool { + return (*pagination.Paginator)(p).IsActive(page) +} + +// Offset Returns the current offset. +func (p *Paginator) Offset() int { + return (*pagination.Paginator)(p).Offset() +} + +// HasPages Returns true if there is more than one page. +func (p *Paginator) HasPages() bool { + return (*pagination.Paginator)(p).HasPages() +} + +// NewPaginator Instantiates a paginator struct for the current http request. +func NewPaginator(req *http.Request, per int, nums interface{}) *Paginator { + return (*Paginator)(pagination.NewPaginator(req, per, nums)) +} diff --git a/pkg/adapter/utils/rand.go b/pkg/adapter/utils/rand.go new file mode 100644 index 00000000..ae415cf3 --- /dev/null +++ b/pkg/adapter/utils/rand.go @@ -0,0 +1,24 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// RandomCreateBytes generate random []byte by specify chars. +func RandomCreateBytes(n int, alphabets ...byte) []byte { + return utils.RandomCreateBytes(n, alphabets...) +} diff --git a/pkg/adapter/utils/rand_test.go b/pkg/adapter/utils/rand_test.go new file mode 100644 index 00000000..6c238b5e --- /dev/null +++ b/pkg/adapter/utils/rand_test.go @@ -0,0 +1,33 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestRand_01(t *testing.T) { + bs0 := RandomCreateBytes(16) + bs1 := RandomCreateBytes(16) + + t.Log(string(bs0), string(bs1)) + if string(bs0) == string(bs1) { + t.FailNow() + } + + bs0 = RandomCreateBytes(4, []byte(`a`)...) + + if string(bs0) != "aaaa" { + t.FailNow() + } +} diff --git a/pkg/adapter/utils/safemap.go b/pkg/adapter/utils/safemap.go new file mode 100644 index 00000000..13e7bb46 --- /dev/null +++ b/pkg/adapter/utils/safemap.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// BeeMap is a map with lock +type BeeMap utils.BeeMap + +// NewBeeMap return new safemap +func NewBeeMap() *BeeMap { + return (*BeeMap)(utils.NewBeeMap()) +} + +// Get from maps return the k's value +func (m *BeeMap) Get(k interface{}) interface{} { + return (*utils.BeeMap)(m).Get(k) +} + +// Set Maps the given key and value. Returns false +// if the key is already in the map and changes nothing. +func (m *BeeMap) Set(k interface{}, v interface{}) bool { + return (*utils.BeeMap)(m).Set(k, v) +} + +// Check Returns true if k is exist in the map. +func (m *BeeMap) Check(k interface{}) bool { + return (*utils.BeeMap)(m).Check(k) +} + +// Delete the given key and value. +func (m *BeeMap) Delete(k interface{}) { + (*utils.BeeMap)(m).Delete(k) +} + +// Items returns all items in safemap. +func (m *BeeMap) Items() map[interface{}]interface{} { + return (*utils.BeeMap)(m).Items() +} + +// Count returns the number of items within the map. +func (m *BeeMap) Count() int { + return (*utils.BeeMap)(m).Count() +} diff --git a/pkg/adapter/utils/safemap_test.go b/pkg/adapter/utils/safemap_test.go new file mode 100644 index 00000000..65085195 --- /dev/null +++ b/pkg/adapter/utils/safemap_test.go @@ -0,0 +1,89 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +var safeMap *BeeMap + +func TestNewBeeMap(t *testing.T) { + safeMap = NewBeeMap() + if safeMap == nil { + t.Fatal("expected to return non-nil BeeMap", "got", safeMap) + } +} + +func TestSet(t *testing.T) { + safeMap = NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } +} + +func TestReSet(t *testing.T) { + safeMap := NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } + // set diff value + if ok := safeMap.Set("astaxie", -1); !ok { + t.Error("expected", true, "got", false) + } + + // set same value + if ok := safeMap.Set("astaxie", -1); ok { + t.Error("expected", false, "got", true) + } +} + +func TestCheck(t *testing.T) { + if exists := safeMap.Check("astaxie"); !exists { + t.Error("expected", true, "got", false) + } +} + +func TestGet(t *testing.T) { + if val := safeMap.Get("astaxie"); val.(int) != 1 { + t.Error("expected value", 1, "got", val) + } +} + +func TestDelete(t *testing.T) { + safeMap.Delete("astaxie") + if exists := safeMap.Check("astaxie"); exists { + t.Error("expected element to be deleted") + } +} + +func TestItems(t *testing.T) { + safeMap := NewBeeMap() + safeMap.Set("astaxie", "hello") + for k, v := range safeMap.Items() { + key := k.(string) + value := v.(string) + if key != "astaxie" { + t.Error("expected the key should be astaxie") + } + if value != "hello" { + t.Error("expected the value should be hello") + } + } +} + +func TestCount(t *testing.T) { + if count := safeMap.Count(); count != 0 { + t.Error("expected count to be", 0, "got", count) + } +} diff --git a/pkg/adapter/utils/slice.go b/pkg/adapter/utils/slice.go new file mode 100644 index 00000000..24d19ad2 --- /dev/null +++ b/pkg/adapter/utils/slice.go @@ -0,0 +1,101 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +type reducetype func(interface{}) interface{} +type filtertype func(interface{}) bool + +// InSlice checks given string in string slice or not. +func InSlice(v string, sl []string) bool { + return utils.InSlice(v, sl) +} + +// InSliceIface checks given interface in interface slice. +func InSliceIface(v interface{}, sl []interface{}) bool { + return utils.InSliceIface(v, sl) +} + +// SliceRandList generate an int slice from min to max. +func SliceRandList(min, max int) []int { + return utils.SliceRandList(min, max) +} + +// SliceMerge merges interface slices to one slice. +func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) { + return utils.SliceMerge(slice1, slice2) +} + +// SliceReduce generates a new slice after parsing every value by reduce function +func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) { + return utils.SliceReduce(slice, func(i interface{}) interface{} { + return a(i) + }) +} + +// SliceRand returns random one from slice. +func SliceRand(a []interface{}) (b interface{}) { + return utils.SliceRand(a) +} + +// SliceSum sums all values in int64 slice. +func SliceSum(intslice []int64) (sum int64) { + return utils.SliceSum(intslice) +} + +// SliceFilter generates a new slice after filter function. +func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) { + return utils.SliceFilter(slice, func(i interface{}) bool { + return a(i) + }) +} + +// SliceDiff returns diff slice of slice1 - slice2. +func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) { + return utils.SliceDiff(slice1, slice2) +} + +// SliceIntersect returns slice that are present in all the slice1 and slice2. +func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) { + return utils.SliceIntersect(slice1, slice2) +} + +// SliceChunk separates one slice to some sized slice. +func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) { + return utils.SliceChunk(slice, size) +} + +// SliceRange generates a new slice from begin to end with step duration of int64 number. +func SliceRange(start, end, step int64) (intslice []int64) { + return utils.SliceRange(start, end, step) +} + +// SlicePad prepends size number of val into slice. +func SlicePad(slice []interface{}, size int, val interface{}) []interface{} { + return utils.SlicePad(slice, size, val) +} + +// SliceUnique cleans repeated values in slice. +func SliceUnique(slice []interface{}) (uniqueslice []interface{}) { + return utils.SliceUnique(slice) +} + +// SliceShuffle shuffles a slice. +func SliceShuffle(slice []interface{}) []interface{} { + return utils.SliceShuffle(slice) +} diff --git a/pkg/adapter/utils/slice_test.go b/pkg/adapter/utils/slice_test.go new file mode 100644 index 00000000..142dec96 --- /dev/null +++ b/pkg/adapter/utils/slice_test.go @@ -0,0 +1,29 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +func TestInSlice(t *testing.T) { + sl := []string{"A", "b"} + if !InSlice("A", sl) { + t.Error("should be true") + } + if InSlice("B", sl) { + t.Error("should be false") + } +} diff --git a/pkg/adapter/utils/utils.go b/pkg/adapter/utils/utils.go new file mode 100644 index 00000000..1f3bcd31 --- /dev/null +++ b/pkg/adapter/utils/utils.go @@ -0,0 +1,10 @@ +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// GetGOPATHs returns all paths in GOPATH variable. +func GetGOPATHs() []string { + return utils.GetGOPATHs() +} From 5b3dd7e50f4fde914c2a919ce35f77b4d5c19fe0 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 6 Sep 2020 13:33:43 +0800 Subject: [PATCH 28/35] Adapter: orm --- pkg/adapter/orm/cmd.go | 28 ++ pkg/adapter/orm/db.go | 24 + pkg/adapter/orm/db_alias.go | 124 +++++ pkg/adapter/orm/models.go | 25 + pkg/adapter/orm/models_boot.go | 40 ++ pkg/adapter/orm/models_fields.go | 625 ++++++++++++++++++++++++ pkg/adapter/orm/orm.go | 314 ++++++++++++ pkg/adapter/orm/orm_conds.go | 83 ++++ pkg/adapter/orm/orm_log.go | 32 ++ pkg/adapter/orm/orm_queryset.go | 32 ++ pkg/adapter/orm/qb.go | 27 + pkg/adapter/orm/qb_mysql.go | 150 ++++++ pkg/adapter/orm/qb_tidb.go | 147 ++++++ pkg/adapter/orm/query_setter_adapter.go | 34 ++ pkg/adapter/orm/types.go | 150 ++++++ pkg/adapter/orm/utils.go | 286 +++++++++++ pkg/adapter/orm/utils_test.go | 70 +++ pkg/client/orm/db_alias.go | 60 ++- pkg/client/orm/orm.go | 6 +- 19 files changed, 2227 insertions(+), 30 deletions(-) create mode 100644 pkg/adapter/orm/cmd.go create mode 100644 pkg/adapter/orm/db.go create mode 100644 pkg/adapter/orm/db_alias.go create mode 100644 pkg/adapter/orm/models.go create mode 100644 pkg/adapter/orm/models_boot.go create mode 100644 pkg/adapter/orm/models_fields.go create mode 100644 pkg/adapter/orm/orm.go create mode 100644 pkg/adapter/orm/orm_conds.go create mode 100644 pkg/adapter/orm/orm_log.go create mode 100644 pkg/adapter/orm/orm_queryset.go create mode 100644 pkg/adapter/orm/qb.go create mode 100644 pkg/adapter/orm/qb_mysql.go create mode 100644 pkg/adapter/orm/qb_tidb.go create mode 100644 pkg/adapter/orm/query_setter_adapter.go create mode 100644 pkg/adapter/orm/types.go create mode 100644 pkg/adapter/orm/utils.go create mode 100644 pkg/adapter/orm/utils_test.go diff --git a/pkg/adapter/orm/cmd.go b/pkg/adapter/orm/cmd.go new file mode 100644 index 00000000..6fee237c --- /dev/null +++ b/pkg/adapter/orm/cmd.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// RunCommand listen for orm command and then run it if command arguments passed. +func RunCommand() { + orm.RunCommand() +} + +func RunSyncdb(name string, force bool, verbose bool) error { + return orm.RunSyncdb(name, force, verbose) +} diff --git a/pkg/adapter/orm/db.go b/pkg/adapter/orm/db.go new file mode 100644 index 00000000..74bca8c0 --- /dev/null +++ b/pkg/adapter/orm/db.go @@ -0,0 +1,24 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +var ( + // ErrMissPK missing pk error + ErrMissPK = orm.ErrMissPK +) diff --git a/pkg/adapter/orm/db_alias.go b/pkg/adapter/orm/db_alias.go new file mode 100644 index 00000000..2ecc80e5 --- /dev/null +++ b/pkg/adapter/orm/db_alias.go @@ -0,0 +1,124 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "time" + + "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/pkg/client/orm/hints" + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// DriverType database driver constant int. +type DriverType orm.DriverType + +// Enum the Database driver +const ( + _ DriverType = iota // int enum type + DRMySQL = orm.DRMySQL + DRSqlite = orm.DRSqlite // sqlite + DROracle = orm.DROracle // oracle + DRPostgres = orm.DRPostgres // pgsql + DRTiDB = orm.DRTiDB // TiDB +) + +type DB orm.DB + +func (d *DB) Begin() (*sql.Tx, error) { + return (*orm.DB)(d).Begin() +} + +func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return (*orm.DB)(d).BeginTx(ctx, opts) +} + +func (d *DB) Prepare(query string) (*sql.Stmt, error) { + return (*orm.DB)(d).Prepare(query) +} + +func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return (*orm.DB)(d).PrepareContext(ctx, query) +} + +func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + return (*orm.DB)(d).Exec(query, args...) +} + +func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return (*orm.DB)(d).ExecContext(ctx, query, args...) +} + +func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + return (*orm.DB)(d).Query(query, args...) +} + +func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return (*orm.DB)(d).QueryContext(ctx, query, args...) +} + +func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { + return (*orm.DB)(d).QueryRow(query, args) +} + +func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return (*orm.DB)(d).QueryRowContext(ctx, query, args...) +} + +// AddAliasWthDB add a aliasName for the drivename +func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { + return orm.AddAliasWthDB(aliasName, driverName, db) +} + +// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. +func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { + opts := make([]utils.KV, 0, 2) + if len(params) > 0 { + opts = append(opts, hints.MaxIdleConnections(params[0])) + } + + if len(params) > 1 { + opts = append(opts, hints.MaxOpenConnections(params[1])) + } + return orm.RegisterDataBase(aliasName, driverName, dataSource, opts...) +} + +// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. +func RegisterDriver(driverName string, typ DriverType) error { + return orm.RegisterDriver(driverName, orm.DriverType(typ)) +} + +// SetDataBaseTZ Change the database default used timezone +func SetDataBaseTZ(aliasName string, tz *time.Location) error { + return orm.SetDataBaseTZ(aliasName, tz) +} + +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +func SetMaxIdleConns(aliasName string, maxIdleConns int) { + orm.SetMaxIdleConns(aliasName, maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +func SetMaxOpenConns(aliasName string, maxOpenConns int) { + orm.SetMaxOpenConns(aliasName, maxOpenConns) +} + +// GetDB Get *sql.DB from registered database by db alias name. +// Use "default" as alias name if you not set. +func GetDB(aliasNames ...string) (*sql.DB, error) { + return orm.GetDB(aliasNames...) +} diff --git a/pkg/adapter/orm/models.go b/pkg/adapter/orm/models.go new file mode 100644 index 00000000..3215f5b5 --- /dev/null +++ b/pkg/adapter/orm/models.go @@ -0,0 +1,25 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// ResetModelCache Clean model cache. Then you can re-RegisterModel. +// Common use this api for test case. +func ResetModelCache() { + orm.ResetModelCache() +} diff --git a/pkg/adapter/orm/models_boot.go b/pkg/adapter/orm/models_boot.go new file mode 100644 index 00000000..8888ef65 --- /dev/null +++ b/pkg/adapter/orm/models_boot.go @@ -0,0 +1,40 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// RegisterModel register models +func RegisterModel(models ...interface{}) { + orm.RegisterModel(models...) +} + +// RegisterModelWithPrefix register models with a prefix +func RegisterModelWithPrefix(prefix string, models ...interface{}) { + orm.RegisterModelWithPrefix(prefix, models) +} + +// RegisterModelWithSuffix register models with a suffix +func RegisterModelWithSuffix(suffix string, models ...interface{}) { + orm.RegisterModelWithSuffix(suffix, models...) +} + +// BootStrap bootstrap models. +// make all model parsed and can not add more models +func BootStrap() { + orm.BootStrap() +} diff --git a/pkg/adapter/orm/models_fields.go b/pkg/adapter/orm/models_fields.go new file mode 100644 index 00000000..666a97dc --- /dev/null +++ b/pkg/adapter/orm/models_fields.go @@ -0,0 +1,625 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "time" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// Define the Type enum +const ( + TypeBooleanField = orm.TypeBooleanField + TypeVarCharField = orm.TypeVarCharField + TypeCharField = orm.TypeCharField + TypeTextField = orm.TypeTextField + TypeTimeField = orm.TypeTimeField + TypeDateField = orm.TypeDateField + TypeDateTimeField = orm.TypeDateTimeField + TypeBitField = orm.TypeBitField + TypeSmallIntegerField = orm.TypeSmallIntegerField + TypeIntegerField = orm.TypeIntegerField + TypeBigIntegerField = orm.TypeBigIntegerField + TypePositiveBitField = orm.TypePositiveBitField + TypePositiveSmallIntegerField = orm.TypePositiveSmallIntegerField + TypePositiveIntegerField = orm.TypePositiveIntegerField + TypePositiveBigIntegerField = orm.TypePositiveBigIntegerField + TypeFloatField = orm.TypeFloatField + TypeDecimalField = orm.TypeDecimalField + TypeJSONField = orm.TypeJSONField + TypeJsonbField = orm.TypeJsonbField + RelForeignKey = orm.RelForeignKey + RelOneToOne = orm.RelOneToOne + RelManyToMany = orm.RelManyToMany + RelReverseOne = orm.RelReverseOne + RelReverseMany = orm.RelReverseMany +) + +// Define some logic enum +const ( + IsIntegerField = orm.IsIntegerField + IsPositiveIntegerField = orm.IsPositiveIntegerField + IsRelField = orm.IsRelField + IsFieldType = orm.IsFieldType +) + +// BooleanField A true/false field. +type BooleanField orm.BooleanField + +// Value return the BooleanField +func (e BooleanField) Value() bool { + return orm.BooleanField(e).Value() +} + +// Set will set the BooleanField +func (e *BooleanField) Set(d bool) { + (*orm.BooleanField)(e).Set(d) +} + +// String format the Bool to string +func (e *BooleanField) String() string { + return (*orm.BooleanField)(e).String() +} + +// FieldType return BooleanField the type +func (e *BooleanField) FieldType() int { + return (*orm.BooleanField)(e).FieldType() +} + +// SetRaw set the interface to bool +func (e *BooleanField) SetRaw(value interface{}) error { + return (*orm.BooleanField)(e).SetRaw(value) +} + +// RawValue return the current value +func (e *BooleanField) RawValue() interface{} { + return (*orm.BooleanField)(e).RawValue() +} + +// verify the BooleanField implement the Fielder interface +var _ Fielder = new(BooleanField) + +// CharField A string field +// required values tag: size +// The size is enforced at the database level and in models’s validation. +// eg: `orm:"size(120)"` +type CharField orm.CharField + +// Value return the CharField's Value +func (e CharField) Value() string { + return orm.CharField(e).Value() +} + +// Set CharField value +func (e *CharField) Set(d string) { + (*orm.CharField)(e).Set(d) +} + +// String return the CharField +func (e *CharField) String() string { + return (*orm.CharField)(e).String() +} + +// FieldType return the enum type +func (e *CharField) FieldType() int { + return (*orm.CharField)(e).FieldType() +} + +// SetRaw set the interface to string +func (e *CharField) SetRaw(value interface{}) error { + return (*orm.CharField)(e).SetRaw(value) +} + +// RawValue return the CharField value +func (e *CharField) RawValue() interface{} { + return (*orm.CharField)(e).RawValue() +} + +// verify CharField implement Fielder +var _ Fielder = new(CharField) + +// TimeField A time, represented in go by a time.Time instance. +// only time values like 10:00:00 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type TimeField orm.TimeField + +// Value return the time.Time +func (e TimeField) Value() time.Time { + return orm.TimeField(e).Value() +} + +// Set set the TimeField's value +func (e *TimeField) Set(d time.Time) { + (*orm.TimeField)(e).Set(d) +} + +// String convert time to string +func (e *TimeField) String() string { + return (*orm.TimeField)(e).String() +} + +// FieldType return enum type Date +func (e *TimeField) FieldType() int { + return (*orm.TimeField)(e).FieldType() +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *TimeField) SetRaw(value interface{}) error { + return (*orm.TimeField)(e).SetRaw(value) +} + +// RawValue return time value +func (e *TimeField) RawValue() interface{} { + return (*orm.TimeField)(e).RawValue() +} + +var _ Fielder = new(TimeField) + +// DateField A date, represented in go by a time.Time instance. +// only date values like 2006-01-02 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type DateField orm.DateField + +// Value return the time.Time +func (e DateField) Value() time.Time { + return orm.DateField(e).Value() +} + +// Set set the DateField's value +func (e *DateField) Set(d time.Time) { + (*orm.DateField)(e).Set(d) +} + +// String convert datetime to string +func (e *DateField) String() string { + return (*orm.DateField)(e).String() +} + +// FieldType return enum type Date +func (e *DateField) FieldType() int { + return (*orm.DateField)(e).FieldType() +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *DateField) SetRaw(value interface{}) error { + return (*orm.DateField)(e).SetRaw(value) +} + +// RawValue return Date value +func (e *DateField) RawValue() interface{} { + return (*orm.DateField)(e).RawValue() +} + +// verify DateField implement fielder interface +var _ Fielder = new(DateField) + +// DateTimeField A date, represented in go by a time.Time instance. +// datetime values like 2006-01-02 15:04:05 +// Takes the same extra arguments as DateField. +type DateTimeField orm.DateTimeField + +// Value return the datetime value +func (e DateTimeField) Value() time.Time { + return orm.DateTimeField(e).Value() +} + +// Set set the time.Time to datetime +func (e *DateTimeField) Set(d time.Time) { + (*orm.DateTimeField)(e).Set(d) +} + +// String return the time's String +func (e *DateTimeField) String() string { + return (*orm.DateTimeField)(e).String() +} + +// FieldType return the enum TypeDateTimeField +func (e *DateTimeField) FieldType() int { + return (*orm.DateTimeField)(e).FieldType() +} + +// SetRaw convert the string or time.Time to DateTimeField +func (e *DateTimeField) SetRaw(value interface{}) error { + return (*orm.DateTimeField)(e).SetRaw(value) +} + +// RawValue return the datetime value +func (e *DateTimeField) RawValue() interface{} { + return (*orm.DateTimeField)(e).RawValue() +} + +// verify datetime implement fielder +var _ Fielder = new(DateTimeField) + +// FloatField A floating-point number represented in go by a float32 value. +type FloatField orm.FloatField + +// Value return the FloatField value +func (e FloatField) Value() float64 { + return orm.FloatField(e).Value() +} + +// Set the Float64 +func (e *FloatField) Set(d float64) { + (*orm.FloatField)(e).Set(d) +} + +// String return the string +func (e *FloatField) String() string { + return (*orm.FloatField)(e).String() +} + +// FieldType return the enum type +func (e *FloatField) FieldType() int { + return (*orm.FloatField)(e).FieldType() +} + +// SetRaw converter interface Float64 float32 or string to FloatField +func (e *FloatField) SetRaw(value interface{}) error { + return (*orm.FloatField)(e).SetRaw(value) +} + +// RawValue return the FloatField value +func (e *FloatField) RawValue() interface{} { + return (*orm.FloatField)(e).RawValue() +} + +// verify FloatField implement Fielder +var _ Fielder = new(FloatField) + +// SmallIntegerField -32768 to 32767 +type SmallIntegerField orm.SmallIntegerField + +// Value return int16 value +func (e SmallIntegerField) Value() int16 { + return orm.SmallIntegerField(e).Value() +} + +// Set the SmallIntegerField value +func (e *SmallIntegerField) Set(d int16) { + (*orm.SmallIntegerField)(e).Set(d) +} + +// String convert smallint to string +func (e *SmallIntegerField) String() string { + return (*orm.SmallIntegerField)(e).String() +} + +// FieldType return enum type SmallIntegerField +func (e *SmallIntegerField) FieldType() int { + return (*orm.SmallIntegerField)(e).FieldType() +} + +// SetRaw convert interface int16/string to int16 +func (e *SmallIntegerField) SetRaw(value interface{}) error { + return (*orm.SmallIntegerField)(e).SetRaw(value) +} + +// RawValue return smallint value +func (e *SmallIntegerField) RawValue() interface{} { + return (*orm.SmallIntegerField)(e).RawValue() +} + +// verify SmallIntegerField implement Fielder +var _ Fielder = new(SmallIntegerField) + +// IntegerField -2147483648 to 2147483647 +type IntegerField orm.IntegerField + +// Value return the int32 +func (e IntegerField) Value() int32 { + return orm.IntegerField(e).Value() +} + +// Set IntegerField value +func (e *IntegerField) Set(d int32) { + (*orm.IntegerField)(e).Set(d) +} + +// String convert Int32 to string +func (e *IntegerField) String() string { + return (*orm.IntegerField)(e).String() +} + +// FieldType return the enum type +func (e *IntegerField) FieldType() int { + return (*orm.IntegerField)(e).FieldType() +} + +// SetRaw convert interface int32/string to int32 +func (e *IntegerField) SetRaw(value interface{}) error { + return (*orm.IntegerField)(e).SetRaw(value) +} + +// RawValue return IntegerField value +func (e *IntegerField) RawValue() interface{} { + return (*orm.IntegerField)(e).RawValue() +} + +// verify IntegerField implement Fielder +var _ Fielder = new(IntegerField) + +// BigIntegerField -9223372036854775808 to 9223372036854775807. +type BigIntegerField orm.BigIntegerField + +// Value return int64 +func (e BigIntegerField) Value() int64 { + return orm.BigIntegerField(e).Value() +} + +// Set the BigIntegerField value +func (e *BigIntegerField) Set(d int64) { + (*orm.BigIntegerField)(e).Set(d) +} + +// String convert BigIntegerField to string +func (e *BigIntegerField) String() string { + return (*orm.BigIntegerField)(e).String() +} + +// FieldType return enum type +func (e *BigIntegerField) FieldType() int { + return (*orm.BigIntegerField)(e).FieldType() +} + +// SetRaw convert interface int64/string to int64 +func (e *BigIntegerField) SetRaw(value interface{}) error { + return (*orm.BigIntegerField)(e).SetRaw(value) +} + +// RawValue return BigIntegerField value +func (e *BigIntegerField) RawValue() interface{} { + return (*orm.BigIntegerField)(e).RawValue() +} + +// verify BigIntegerField implement Fielder +var _ Fielder = new(BigIntegerField) + +// PositiveSmallIntegerField 0 to 65535 +type PositiveSmallIntegerField orm.PositiveSmallIntegerField + +// Value return uint16 +func (e PositiveSmallIntegerField) Value() uint16 { + return orm.PositiveSmallIntegerField(e).Value() +} + +// Set PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) Set(d uint16) { + (*orm.PositiveSmallIntegerField)(e).Set(d) +} + +// String convert uint16 to string +func (e *PositiveSmallIntegerField) String() string { + return (*orm.PositiveSmallIntegerField)(e).String() +} + +// FieldType return enum type +func (e *PositiveSmallIntegerField) FieldType() int { + return (*orm.PositiveSmallIntegerField)(e).FieldType() +} + +// SetRaw convert Interface uint16/string to uint16 +func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { + return (*orm.PositiveSmallIntegerField)(e).SetRaw(value) +} + +// RawValue returns PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) RawValue() interface{} { + return (*orm.PositiveSmallIntegerField)(e).RawValue() +} + +// verify PositiveSmallIntegerField implement Fielder +var _ Fielder = new(PositiveSmallIntegerField) + +// PositiveIntegerField 0 to 4294967295 +type PositiveIntegerField orm.PositiveIntegerField + +// Value return PositiveIntegerField value. Uint32 +func (e PositiveIntegerField) Value() uint32 { + return orm.PositiveIntegerField(e).Value() +} + +// Set the PositiveIntegerField value +func (e *PositiveIntegerField) Set(d uint32) { + (*orm.PositiveIntegerField)(e).Set(d) +} + +// String convert PositiveIntegerField to string +func (e *PositiveIntegerField) String() string { + return (*orm.PositiveIntegerField)(e).String() +} + +// FieldType return enum type +func (e *PositiveIntegerField) FieldType() int { + return (*orm.PositiveIntegerField)(e).FieldType() +} + +// SetRaw convert interface uint32/string to Uint32 +func (e *PositiveIntegerField) SetRaw(value interface{}) error { + return (*orm.PositiveIntegerField)(e).SetRaw(value) +} + +// RawValue return the PositiveIntegerField Value +func (e *PositiveIntegerField) RawValue() interface{} { + return (*orm.PositiveIntegerField)(e).RawValue() +} + +// verify PositiveIntegerField implement Fielder +var _ Fielder = new(PositiveIntegerField) + +// PositiveBigIntegerField 0 to 18446744073709551615 +type PositiveBigIntegerField orm.PositiveBigIntegerField + +// Value return uint64 +func (e PositiveBigIntegerField) Value() uint64 { + return orm.PositiveBigIntegerField(e).Value() +} + +// Set PositiveBigIntegerField value +func (e *PositiveBigIntegerField) Set(d uint64) { + (*orm.PositiveBigIntegerField)(e).Set(d) +} + +// String convert PositiveBigIntegerField to string +func (e *PositiveBigIntegerField) String() string { + return (*orm.PositiveBigIntegerField)(e).String() +} + +// FieldType return enum type +func (e *PositiveBigIntegerField) FieldType() int { + return (*orm.PositiveBigIntegerField)(e).FieldType() +} + +// SetRaw convert interface uint64/string to Uint64 +func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { + return (*orm.PositiveBigIntegerField)(e).SetRaw(value) +} + +// RawValue return PositiveBigIntegerField value +func (e *PositiveBigIntegerField) RawValue() interface{} { + return (*orm.PositiveBigIntegerField)(e).RawValue() +} + +// verify PositiveBigIntegerField implement Fielder +var _ Fielder = new(PositiveBigIntegerField) + +// TextField A large text field. +type TextField orm.TextField + +// Value return TextField value +func (e TextField) Value() string { + return orm.TextField(e).Value() +} + +// Set the TextField value +func (e *TextField) Set(d string) { + (*orm.TextField)(e).Set(d) +} + +// String convert TextField to string +func (e *TextField) String() string { + return (*orm.TextField)(e).String() +} + +// FieldType return enum type +func (e *TextField) FieldType() int { + return (*orm.TextField)(e).FieldType() +} + +// SetRaw convert interface string to string +func (e *TextField) SetRaw(value interface{}) error { + return (*orm.TextField)(e).SetRaw(value) +} + +// RawValue return TextField value +func (e *TextField) RawValue() interface{} { + return (*orm.TextField)(e).RawValue() +} + +// verify TextField implement Fielder +var _ Fielder = new(TextField) + +// JSONField postgres json field. +type JSONField orm.JSONField + +// Value return JSONField value +func (j JSONField) Value() string { + return orm.JSONField(j).Value() +} + +// Set the JSONField value +func (j *JSONField) Set(d string) { + (*orm.JSONField)(j).Set(d) +} + +// String convert JSONField to string +func (j *JSONField) String() string { + return (*orm.JSONField)(j).String() +} + +// FieldType return enum type +func (j *JSONField) FieldType() int { + return (*orm.JSONField)(j).FieldType() +} + +// SetRaw convert interface string to string +func (j *JSONField) SetRaw(value interface{}) error { + return (*orm.JSONField)(j).SetRaw(value) +} + +// RawValue return JSONField value +func (j *JSONField) RawValue() interface{} { + return (*orm.JSONField)(j).RawValue() +} + +// verify JSONField implement Fielder +var _ Fielder = new(JSONField) + +// JsonbField postgres json field. +type JsonbField orm.JsonbField + +// Value return JsonbField value +func (j JsonbField) Value() string { + return orm.JsonbField(j).Value() +} + +// Set the JsonbField value +func (j *JsonbField) Set(d string) { + (*orm.JsonbField)(j).Set(d) +} + +// String convert JsonbField to string +func (j *JsonbField) String() string { + return (*orm.JsonbField)(j).String() +} + +// FieldType return enum type +func (j *JsonbField) FieldType() int { + return (*orm.JsonbField)(j).FieldType() +} + +// SetRaw convert interface string to string +func (j *JsonbField) SetRaw(value interface{}) error { + return (*orm.JsonbField)(j).SetRaw(value) +} + +// RawValue return JsonbField value +func (j *JsonbField) RawValue() interface{} { + return (*orm.JsonbField)(j).RawValue() +} + +// verify JsonbField implement Fielder +var _ Fielder = new(JsonbField) diff --git a/pkg/adapter/orm/orm.go b/pkg/adapter/orm/orm.go new file mode 100644 index 00000000..f8463ea2 --- /dev/null +++ b/pkg/adapter/orm/orm.go @@ -0,0 +1,314 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +// Package orm provide ORM for MySQL/PostgreSQL/sqlite +// Simple Usage +// +// package main +// +// import ( +// "fmt" +// "github.com/astaxie/beego/orm" +// _ "github.com/go-sql-driver/mysql" // import your used driver +// ) +// +// // Model Struct +// type User struct { +// Id int `orm:"auto"` +// Name string `orm:"size(100)"` +// } +// +// func init() { +// orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) +// } +// +// func main() { +// o := orm.NewOrm() +// user := User{Name: "slene"} +// // insert +// id, err := o.Insert(&user) +// // update +// user.Name = "astaxie" +// num, err := o.Update(&user) +// // read one +// u := User{Id: user.Id} +// err = o.Read(&u) +// // delete +// num, err = o.Delete(&u) +// } +// +// more docs: http://beego.me/docs/mvc/model/overview.md +package orm + +import ( + "context" + "database/sql" + "errors" + + "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/pkg/client/orm/hints" + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// DebugQueries define the debug +const ( + DebugQueries = iota +) + +// Define common vars +var ( + Debug = orm.Debug + DebugLog = orm.DebugLog + DefaultRowsLimit = orm.DefaultRowsLimit + DefaultRelsDepth = orm.DefaultRelsDepth + DefaultTimeLoc = orm.DefaultTimeLoc + ErrTxHasBegan = errors.New(" transaction already begin") + ErrTxDone = errors.New(" transaction not begin") + ErrMultiRows = errors.New(" return multi rows") + ErrNoRows = errors.New(" no row found") + ErrStmtClosed = errors.New(" stmt already closed") + ErrArgs = errors.New(" args error may be empty") + ErrNotImplement = errors.New("have not implement") +) + +type ormer struct { + delegate orm.Ormer + txDelegate orm.TxOrmer + isTx bool +} + +var _ Ormer = new(ormer) + +// read data to model +func (o *ormer) Read(md interface{}, cols ...string) error { + if o.isTx { + return o.txDelegate.Read(md, cols...) + } + return o.delegate.Read(md, cols...) +} + +// read data to model, like Read(), but use "SELECT FOR UPDATE" form +func (o *ormer) ReadForUpdate(md interface{}, cols ...string) error { + if o.isTx { + return o.txDelegate.ReadForUpdate(md, cols...) + } + return o.delegate.ReadForUpdate(md, cols...) +} + +// Try to read a row from the database, or insert one if it doesn't exist +func (o *ormer) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + if o.isTx { + return o.txDelegate.ReadOrCreate(md, col1, cols...) + } + return o.delegate.ReadOrCreate(md, col1, cols...) +} + +// insert model data to database +func (o *ormer) Insert(md interface{}) (int64, error) { + if o.isTx { + return o.txDelegate.Insert(md) + } + return o.delegate.Insert(md) +} + +// insert some models to database +func (o *ormer) InsertMulti(bulk int, mds interface{}) (int64, error) { + if o.isTx { + return o.txDelegate.InsertMulti(bulk, mds) + } + return o.delegate.InsertMulti(bulk, mds) +} + +// InsertOrUpdate data to database +func (o *ormer) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { + if o.isTx { + return o.txDelegate.InsertOrUpdate(md, colConflitAndArgs...) + } + return o.delegate.InsertOrUpdate(md, colConflitAndArgs...) +} + +// update model to database. +// cols set the columns those want to update. +func (o *ormer) Update(md interface{}, cols ...string) (int64, error) { + if o.isTx { + return o.txDelegate.Update(md, cols...) + } + return o.delegate.Update(md, cols...) +} + +// delete model in database +// cols shows the delete conditions values read from. default is pk +func (o *ormer) Delete(md interface{}, cols ...string) (int64, error) { + if o.isTx { + return o.txDelegate.Delete(md, cols...) + } + return o.delegate.Delete(md, cols...) +} + +// create a models to models queryer +func (o *ormer) QueryM2M(md interface{}, name string) QueryM2Mer { + if o.isTx { + return o.txDelegate.QueryM2M(md, name) + } + return o.delegate.QueryM2M(md, name) +} + +// load related models to md model. +// args are limit, offset int and order string. +// +// example: +// orm.LoadRelated(post,"Tags") +// for _,tag := range post.Tags{...} +// +// make sure the relation is defined in model struct tags. +func (o *ormer) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { + kvs := make([]utils.KV, 0, 4) + for i, arg := range args { + switch i { + case 0: + if v, ok := arg.(bool); ok { + if v { + kvs = append(kvs, hints.DefaultRelDepth()) + } + } else if v, ok := arg.(int); ok { + kvs = append(kvs, hints.RelDepth(v)) + } + case 1: + kvs = append(kvs, hints.Limit(orm.ToInt64(arg))) + case 2: + kvs = append(kvs, hints.Offset(orm.ToInt64(arg))) + case 3: + kvs = append(kvs, hints.Offset(orm.ToInt64(arg))) + } + } + if o.isTx { + return o.txDelegate.LoadRelated(md, name, kvs...) + } + return o.delegate.LoadRelated(md, name, kvs...) +} + +// return a QuerySeter for table operations. +// table name can be string or struct. +// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), +func (o *ormer) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { + if o.isTx { + return o.txDelegate.QueryTable(ptrStructOrTableName) + } + return o.delegate.QueryTable(ptrStructOrTableName) +} + +// switch to another registered database driver by given name. +func (o *ormer) Using(name string) error { + if o.isTx { + return ErrTxHasBegan + } + o.delegate = orm.NewOrmUsingDB(name) + return nil +} + +// begin transaction +func (o *ormer) Begin() error { + if o.isTx { + return ErrTxHasBegan + } + return o.BeginTx(context.Background(), nil) +} + +func (o *ormer) BeginTx(ctx context.Context, opts *sql.TxOptions) error { + if o.isTx { + return ErrTxHasBegan + } + txOrmer, err := o.delegate.BeginWithCtxAndOpts(ctx, opts) + if err != nil { + return err + } + o.txDelegate = txOrmer + o.isTx = true + return nil +} + +// commit transaction +func (o *ormer) Commit() error { + if !o.isTx { + return ErrTxDone + } + err := o.txDelegate.Commit() + if err == nil { + o.isTx = false + o.txDelegate = nil + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// rollback transaction +func (o *ormer) Rollback() error { + if !o.isTx { + return ErrTxDone + } + err := o.txDelegate.Rollback() + if err == nil { + o.isTx = false + o.txDelegate = nil + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// return a raw query seter for raw sql string. +func (o *ormer) Raw(query string, args ...interface{}) RawSeter { + if o.isTx { + return o.txDelegate.Raw(query, args...) + } + return o.delegate.Raw(query, args...) +} + +// return current using database Driver +func (o *ormer) Driver() Driver { + if o.isTx { + return o.txDelegate.Driver() + } + return o.delegate.Driver() +} + +// return sql.DBStats for current database +func (o *ormer) DBStats() *sql.DBStats { + if o.isTx { + return o.txDelegate.DBStats() + } + return o.delegate.DBStats() +} + +// NewOrm create new orm +func NewOrm() Ormer { + o := orm.NewOrm() + return &ormer{ + delegate: o, + } +} + +// NewOrmWithDB create a new ormer object with specify *sql.DB for query +func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { + o, err := orm.NewOrmWithDB(driverName, aliasName, db) + if err != nil { + return nil, err + } + return &ormer{ + delegate: o, + }, nil +} diff --git a/pkg/adapter/orm/orm_conds.go b/pkg/adapter/orm/orm_conds.go new file mode 100644 index 00000000..986b4858 --- /dev/null +++ b/pkg/adapter/orm/orm_conds.go @@ -0,0 +1,83 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// ExprSep define the expression separation +const ( + ExprSep = "__" +) + +// Condition struct. +// work for WHERE conditions. +type Condition orm.Condition + +// NewCondition return new condition struct +func NewCondition() *Condition { + return (*Condition)(orm.NewCondition()) +} + +// Raw add raw sql to condition +func (c Condition) Raw(expr string, sql string) *Condition { + return (*Condition)((orm.Condition)(c).Raw(expr, sql)) +} + +// And add expression to condition +func (c Condition) And(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).And(expr, args...)) +} + +// AndNot add NOT expression to condition +func (c Condition) AndNot(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).AndNot(expr, args...)) +} + +// AndCond combine a condition to current condition +func (c *Condition) AndCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).AndCond((*orm.Condition)(cond))) +} + +// AndNotCond combine a AND NOT condition to current condition +func (c *Condition) AndNotCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).AndNotCond((*orm.Condition)(cond))) +} + +// Or add OR expression to condition +func (c Condition) Or(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).Or(expr, args...)) +} + +// OrNot add OR NOT expression to condition +func (c Condition) OrNot(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).OrNot(expr, args...)) +} + +// OrCond combine a OR condition to current condition +func (c *Condition) OrCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).OrCond((*orm.Condition)(cond))) +} + +// OrNotCond combine a OR NOT condition to current condition +func (c *Condition) OrNotCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).OrNotCond((*orm.Condition)(cond))) +} + +// IsEmpty check the condition arguments are empty or not. +func (c *Condition) IsEmpty() bool { + return (*orm.Condition)(c).IsEmpty() +} diff --git a/pkg/adapter/orm/orm_log.go b/pkg/adapter/orm/orm_log.go new file mode 100644 index 00000000..6b2b4a9b --- /dev/null +++ b/pkg/adapter/orm/orm_log.go @@ -0,0 +1,32 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "io" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// Log implement the log.Logger +type Log orm.Log + +// costomer log func +var LogFunc = orm.LogFunc + +// NewLog set io.Writer to create a Logger. +func NewLog(out io.Writer) *Log { + return (*Log)(orm.NewLog(out)) +} diff --git a/pkg/adapter/orm/orm_queryset.go b/pkg/adapter/orm/orm_queryset.go new file mode 100644 index 00000000..5f211644 --- /dev/null +++ b/pkg/adapter/orm/orm_queryset.go @@ -0,0 +1,32 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// define Col operations +const ( + ColAdd = orm.ColAdd + ColMinus = orm.ColMinus + ColMultiply = orm.ColMultiply + ColExcept = orm.ColExcept + ColBitAnd = orm.ColBitAnd + ColBitRShift = orm.ColBitRShift + ColBitLShift = orm.ColBitLShift + ColBitXOR = orm.ColBitXOR + ColBitOr = orm.ColBitOr +) diff --git a/pkg/adapter/orm/qb.go b/pkg/adapter/orm/qb.go new file mode 100644 index 00000000..90b97797 --- /dev/null +++ b/pkg/adapter/orm/qb.go @@ -0,0 +1,27 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// QueryBuilder is the Query builder interface +type QueryBuilder orm.QueryBuilder + +// NewQueryBuilder return the QueryBuilder +func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { + return orm.NewQueryBuilder(driver) +} diff --git a/pkg/adapter/orm/qb_mysql.go b/pkg/adapter/orm/qb_mysql.go new file mode 100644 index 00000000..9566068f --- /dev/null +++ b/pkg/adapter/orm/qb_mysql.go @@ -0,0 +1,150 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// CommaSpace is the separation +const CommaSpace = orm.CommaSpace + +// MySQLQueryBuilder is the SQL build +type MySQLQueryBuilder orm.MySQLQueryBuilder + +// Select will join the fields +func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Select(fields...) +} + +// ForUpdate add the FOR UPDATE clause +func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).ForUpdate() +} + +// From join the tables +func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).From(tables...) +} + +// InnerJoin INNER JOIN the table +func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).InnerJoin(table) +} + +// LeftJoin LEFT JOIN the table +func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).LeftJoin(table) +} + +// RightJoin RIGHT JOIN the table +func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).RightJoin(table) +} + +// On join with on cond +func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).On(cond) +} + +// Where join the Where cond +func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Where(cond) +} + +// And join the and cond +func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).And(cond) +} + +// Or join the or cond +func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Or(cond) +} + +// In join the IN (vals) +func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).In(vals...) +} + +// OrderBy join the Order by fields +func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).OrderBy(fields...) +} + +// Asc join the asc +func (qb *MySQLQueryBuilder) Asc() QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Asc() +} + +// Desc join the desc +func (qb *MySQLQueryBuilder) Desc() QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Desc() +} + +// Limit join the limit num +func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Limit(limit) +} + +// Offset join the offset num +func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Offset(offset) +} + +// GroupBy join the Group by fields +func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).GroupBy(fields...) +} + +// Having join the Having cond +func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Having(cond) +} + +// Update join the update table +func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Update(tables...) +} + +// Set join the set kv +func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Set(kv...) +} + +// Delete join the Delete tables +func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Delete(tables...) +} + +// InsertInto join the insert SQL +func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).InsertInto(table, fields...) +} + +// Values join the Values(vals) +func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Values(vals...) +} + +// Subquery join the sub as alias +func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { + return (*orm.MySQLQueryBuilder)(qb).Subquery(sub, alias) +} + +// String join all Tokens +func (qb *MySQLQueryBuilder) String() string { + return (*orm.MySQLQueryBuilder)(qb).String() +} diff --git a/pkg/adapter/orm/qb_tidb.go b/pkg/adapter/orm/qb_tidb.go new file mode 100644 index 00000000..05c91a26 --- /dev/null +++ b/pkg/adapter/orm/qb_tidb.go @@ -0,0 +1,147 @@ +// Copyright 2015 TiDB Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// TiDBQueryBuilder is the SQL build +type TiDBQueryBuilder orm.TiDBQueryBuilder + +// Select will join the fields +func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Select(fields...) +} + +// ForUpdate add the FOR UPDATE clause +func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).ForUpdate() +} + +// From join the tables +func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).From(tables...) +} + +// InnerJoin INNER JOIN the table +func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).InnerJoin(table) +} + +// LeftJoin LEFT JOIN the table +func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).LeftJoin(table) +} + +// RightJoin RIGHT JOIN the table +func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).RightJoin(table) +} + +// On join with on cond +func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).On(cond) +} + +// Where join the Where cond +func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Where(cond) +} + +// And join the and cond +func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).And(cond) +} + +// Or join the or cond +func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Or(cond) +} + +// In join the IN (vals) +func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).In(vals...) +} + +// OrderBy join the Order by fields +func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).OrderBy(fields...) +} + +// Asc join the asc +func (qb *TiDBQueryBuilder) Asc() QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Asc() +} + +// Desc join the desc +func (qb *TiDBQueryBuilder) Desc() QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Desc() +} + +// Limit join the limit num +func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Limit(limit) +} + +// Offset join the offset num +func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Offset(offset) +} + +// GroupBy join the Group by fields +func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).GroupBy(fields...) +} + +// Having join the Having cond +func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Having(cond) +} + +// Update join the update table +func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Update(tables...) +} + +// Set join the set kv +func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Set(kv...) +} + +// Delete join the Delete tables +func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Delete(tables...) +} + +// InsertInto join the insert SQL +func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).InsertInto(table, fields...) +} + +// Values join the Values(vals) +func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Values(vals...) +} + +// Subquery join the sub as alias +func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string { + return (*orm.TiDBQueryBuilder)(qb).Subquery(sub, alias) +} + +// String join all Tokens +func (qb *TiDBQueryBuilder) String() string { + return (*orm.TiDBQueryBuilder)(qb).String() +} diff --git a/pkg/adapter/orm/query_setter_adapter.go b/pkg/adapter/orm/query_setter_adapter.go new file mode 100644 index 00000000..cc24ef6b --- /dev/null +++ b/pkg/adapter/orm/query_setter_adapter.go @@ -0,0 +1,34 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +type baseQuerySetter struct { +} + +func (b *baseQuerySetter) ForceIndex(indexes ...string) orm.QuerySeter { + panic("you should not invoke this method.") +} + +func (b *baseQuerySetter) UseIndex(indexes ...string) orm.QuerySeter { + panic("you should not invoke this method.") +} + +func (b *baseQuerySetter) IgnoreIndex(indexes ...string) orm.QuerySeter { + panic("you should not invoke this method.") +} diff --git a/pkg/adapter/orm/types.go b/pkg/adapter/orm/types.go new file mode 100644 index 00000000..3372e301 --- /dev/null +++ b/pkg/adapter/orm/types.go @@ -0,0 +1,150 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// Params stores the Params +type Params orm.Params + +// ParamsList stores paramslist +type ParamsList orm.ParamsList + +// Driver define database driver +type Driver orm.Driver + +// Fielder define field info +type Fielder orm.Fielder + +// Ormer define the orm interface +type Ormer interface { + // read data to model + // for example: + // this will find User by Id field + // u = &User{Id: user.Id} + // err = Ormer.Read(u) + // this will find User by UserName field + // u = &User{UserName: "astaxie", Password: "pass"} + // err = Ormer.Read(u, "UserName") + Read(md interface{}, cols ...string) error + // Like Read(), but with "FOR UPDATE" clause, useful in transaction. + // Some databases are not support this feature. + ReadForUpdate(md interface{}, cols ...string) error + // Try to read a row from the database, or insert one if it doesn't exist + ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) + // insert model data to database + // for example: + // user := new(User) + // id, err = Ormer.Insert(user) + // user must be a pointer and Insert will set user's pk field + Insert(interface{}) (int64, error) + // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") + // if colu type is integer : can use(+-*/), string : convert(colu,"value") + // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") + // if colu type is integer : can use(+-*/), string : colu || "value" + InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) + // insert some models to database + InsertMulti(bulk int, mds interface{}) (int64, error) + // update model to database. + // cols set the columns those want to update. + // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns + // for example: + // user := User{Id: 2} + // user.Langs = append(user.Langs, "zh-CN", "en-US") + // user.Extra.Name = "beego" + // user.Extra.Data = "orm" + // num, err = Ormer.Update(&user, "Langs", "Extra") + Update(md interface{}, cols ...string) (int64, error) + // delete model in database + Delete(md interface{}, cols ...string) (int64, error) + // load related models to md model. + // args are limit, offset int and order string. + // + // example: + // Ormer.LoadRelated(post,"Tags") + // for _,tag := range post.Tags{...} + // args[0] bool true useDefaultRelsDepth ; false depth 0 + // args[0] int loadRelationDepth + // args[1] int limit default limit 1000 + // args[2] int offset default offset 0 + // args[3] string order for example : "-Id" + // make sure the relation is defined in model struct tags. + LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) + // create a models to models queryer + // for example: + // post := Post{Id: 4} + // m2m := Ormer.QueryM2M(&post, "Tags") + QueryM2M(md interface{}, name string) QueryM2Mer + // return a QuerySeter for table operations. + // table name can be string or struct. + // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), + QueryTable(ptrStructOrTableName interface{}) QuerySeter + // switch to another registered database driver by given name. + Using(name string) error + // begin transaction + // for example: + // o := NewOrm() + // err := o.Begin() + // ... + // err = o.Rollback() + Begin() error + // begin transaction with provided context and option + // the provided context is used until the transaction is committed or rolled back. + // if the context is canceled, the transaction will be rolled back. + // the provided TxOptions is optional and may be nil if defaults should be used. + // if a non-default isolation level is used that the driver doesn't support, an error will be returned. + // for example: + // o := NewOrm() + // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + // ... + // err = o.Rollback() + BeginTx(ctx context.Context, opts *sql.TxOptions) error + // commit transaction + Commit() error + // rollback transaction + Rollback() error + // return a raw query seter for raw sql string. + // for example: + // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() + // // update user testing's name to slene + Raw(query string, args ...interface{}) RawSeter + Driver() Driver + DBStats() *sql.DBStats +} + +// Inserter insert prepared statement +type Inserter orm.Inserter + +// QuerySeter query seter +type QuerySeter orm.QuerySeter + +// QueryM2Mer model to model query struct +// all operations are on the m2m table only, will not affect the origin model table +type QueryM2Mer orm.QueryM2Mer + +// RawPreparer raw query statement +type RawPreparer orm.RawPreparer + +// RawSeter raw query seter +// create From Ormer.Raw +// for example: +// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) +// rs := Ormer.Raw(sql, 1) +type RawSeter orm.RawSeter diff --git a/pkg/adapter/orm/utils.go b/pkg/adapter/orm/utils.go new file mode 100644 index 00000000..16d0e4e5 --- /dev/null +++ b/pkg/adapter/orm/utils.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "time" + + "github.com/astaxie/beego/pkg/client/orm" +) + +type fn func(string) string + +var ( + nameStrategyMap = map[string]fn{ + defaultNameStrategy: snakeString, + SnakeAcronymNameStrategy: snakeStringWithAcronym, + } + defaultNameStrategy = "snakeString" + SnakeAcronymNameStrategy = "snakeStringWithAcronym" + nameStrategy = defaultNameStrategy +) + +// StrTo is the target string +type StrTo orm.StrTo + +// Set string +func (f *StrTo) Set(v string) { + (*orm.StrTo)(f).Set(v) +} + +// Clear string +func (f *StrTo) Clear() { + (*orm.StrTo)(f).Clear() +} + +// Exist check string exist +func (f StrTo) Exist() bool { + return orm.StrTo(f).Exist() +} + +// Bool string to bool +func (f StrTo) Bool() (bool, error) { + return orm.StrTo(f).Bool() +} + +// Float32 string to float32 +func (f StrTo) Float32() (float32, error) { + return orm.StrTo(f).Float32() +} + +// Float64 string to float64 +func (f StrTo) Float64() (float64, error) { + return orm.StrTo(f).Float64() +} + +// Int string to int +func (f StrTo) Int() (int, error) { + return orm.StrTo(f).Int() +} + +// Int8 string to int8 +func (f StrTo) Int8() (int8, error) { + return orm.StrTo(f).Int8() +} + +// Int16 string to int16 +func (f StrTo) Int16() (int16, error) { + return orm.StrTo(f).Int16() +} + +// Int32 string to int32 +func (f StrTo) Int32() (int32, error) { + return orm.StrTo(f).Int32() +} + +// Int64 string to int64 +func (f StrTo) Int64() (int64, error) { + return orm.StrTo(f).Int64() +} + +// Uint string to uint +func (f StrTo) Uint() (uint, error) { + return orm.StrTo(f).Uint() +} + +// Uint8 string to uint8 +func (f StrTo) Uint8() (uint8, error) { + return orm.StrTo(f).Uint8() +} + +// Uint16 string to uint16 +func (f StrTo) Uint16() (uint16, error) { + return orm.StrTo(f).Uint16() +} + +// Uint32 string to uint32 +func (f StrTo) Uint32() (uint32, error) { + return orm.StrTo(f).Uint32() +} + +// Uint64 string to uint64 +func (f StrTo) Uint64() (uint64, error) { + return orm.StrTo(f).Uint64() +} + +// String string to string +func (f StrTo) String() string { + return orm.StrTo(f).String() +} + +// ToStr interface to string +func ToStr(value interface{}, args ...int) (s string) { + switch v := value.(type) { + case bool: + s = strconv.FormatBool(v) + case float32: + s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) + case float64: + s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) + case int: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int8: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int16: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int32: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int64: + s = strconv.FormatInt(v, argInt(args).Get(0, 10)) + case uint: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint8: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint16: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint32: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint64: + s = strconv.FormatUint(v, argInt(args).Get(0, 10)) + case string: + s = v + case []byte: + s = string(v) + default: + s = fmt.Sprintf("%v", v) + } + return s +} + +// ToInt64 interface to int64 +func ToInt64(value interface{}) (d int64) { + val := reflect.ValueOf(value) + switch value.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + panic(fmt.Errorf("ToInt64 need numeric not `%T`", value)) + } + return +} + +func snakeStringWithAcronym(s string) string { + data := make([]byte, 0, len(s)*2) + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + before := false + after := false + if i > 0 { + before = s[i-1] >= 'a' && s[i-1] <= 'z' + } + if i+1 < num { + after = s[i+1] >= 'a' && s[i+1] <= 'z' + } + if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { + data = append(data, '_') + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// snake string, XxYy to xx_yy , XxYY to xx_y_y +func snakeString(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + if i > 0 && d >= 'A' && d <= 'Z' && j { + data = append(data, '_') + } + if d != '_' { + j = true + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// SetNameStrategy set different name strategy +func SetNameStrategy(s string) { + if SnakeAcronymNameStrategy != s { + nameStrategy = defaultNameStrategy + } + nameStrategy = s +} + +// camel string, xx_yy to XxYy +func camelString(s string) string { + data := make([]byte, 0, len(s)) + flag, num := true, len(s)-1 + for i := 0; i <= num; i++ { + d := s[i] + if d == '_' { + flag = true + continue + } else if flag { + if d >= 'a' && d <= 'z' { + d = d - 32 + } + flag = false + } + data = append(data, d) + } + return string(data[:]) +} + +type argString []string + +// get string by index from string slice +func (a argString) Get(i int, args ...string) (r string) { + if i >= 0 && i < len(a) { + r = a[i] + } else if len(args) > 0 { + r = args[0] + } + return +} + +type argInt []int + +// get int by index from int slice +func (a argInt) Get(i int, args ...int) (r int) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +// parse time to string with location +func timeParse(dateString, format string) (time.Time, error) { + tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) + return tp, err +} + +// get pointer indirect type +func indirectType(v reflect.Type) reflect.Type { + switch v.Kind() { + case reflect.Ptr: + return indirectType(v.Elem()) + default: + return v + } +} diff --git a/pkg/adapter/orm/utils_test.go b/pkg/adapter/orm/utils_test.go new file mode 100644 index 00000000..7d94cada --- /dev/null +++ b/pkg/adapter/orm/utils_test.go @@ -0,0 +1,70 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" +) + +func TestCamelString(t *testing.T) { + snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} + + answer := make(map[string]string) + for i, v := range snake { + answer[v] = camel[i] + } + + for _, v := range snake { + res := camelString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeString(t *testing.T) { + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeStringWithAcronym(t *testing.T) { + camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeStringWithAcronym(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} diff --git a/pkg/client/orm/db_alias.go b/pkg/client/orm/db_alias.go index 8a5cfb10..c72f29c4 100644 --- a/pkg/client/orm/db_alias.go +++ b/pkg/client/orm/db_alias.go @@ -400,22 +400,47 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...utils.KV detectTZ(al) kvs.IfContains(hints.KeyMaxIdleConnections, func(value interface{}) { - if m, ok := value.(int); ok { - SetMaxIdleConns(al, m) - } + al.SetMaxIdleConns(value.(int)) }).IfContains(hints.KeyMaxOpenConnections, func(value interface{}) { - if m, ok := value.(int); ok { - SetMaxOpenConns(al, m) - } + al.SetMaxOpenConns(value.(int)) }).IfContains(hints.KeyConnMaxLifetime, func(value interface{}) { - if m, ok := value.(time.Duration); ok { - SetConnMaxLifetime(al, m) - } + al.SetConnMaxLifetime(value.(time.Duration)) }) return al, nil } +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +// Deprecated you should not use this, we will remove it in the future +func SetMaxIdleConns(aliasName string, maxIdleConns int) { + al := getDbAlias(aliasName) + al.SetMaxIdleConns(maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +// Deprecated you should not use this, we will remove it in the future +func SetMaxOpenConns(aliasName string, maxOpenConns int) { + al := getDbAlias(aliasName) + al.SetMaxIdleConns(maxOpenConns) +} + +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +func (al *alias) SetMaxIdleConns(maxIdleConns int) { + al.MaxIdleConns = maxIdleConns + al.DB.DB.SetMaxIdleConns(maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +func (al *alias) SetMaxOpenConns(maxOpenConns int) { + al.MaxOpenConns = maxOpenConns + al.DB.DB.SetMaxOpenConns(maxOpenConns) +} + +func (al *alias) SetConnMaxLifetime(lifeTime time.Duration) { + al.ConnMaxLifetime = lifeTime + al.DB.DB.SetConnMaxLifetime(lifeTime) +} + // AddAliasWthDB add a aliasName for the drivename func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...utils.KV) error { _, err := addAliasWthDB(aliasName, driverName, db, params...) @@ -476,23 +501,6 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { return nil } -// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name -func SetMaxIdleConns(al *alias, maxIdleConns int) { - al.MaxIdleConns = maxIdleConns - al.DB.DB.SetMaxIdleConns(maxIdleConns) -} - -// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name -func SetMaxOpenConns(al *alias, maxOpenConns int) { - al.MaxOpenConns = maxOpenConns - al.DB.DB.SetMaxOpenConns(maxOpenConns) -} - -func SetConnMaxLifetime(al *alias, lifeTime time.Duration) { - al.ConnMaxLifetime = lifeTime - al.DB.DB.SetConnMaxLifetime(lifeTime) -} - // GetDB Get *sql.DB from registered database by db alias name. // Use "default" as alias name if you not set. func GetDB(aliasNames ...string) (*sql.DB, error) { diff --git a/pkg/client/orm/orm.go b/pkg/client/orm/orm.go index 634b1892..95bbcb31 100644 --- a/pkg/client/orm/orm.go +++ b/pkg/client/orm/orm.go @@ -311,9 +311,7 @@ func (o *ormBase) LoadRelated(md interface{}, name string, args ...utils.KV) (in return o.LoadRelatedWithCtx(context.Background(), md, name, args...) } func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { - _, fi, ind, qseter := o.queryRelated(md, name) - - qs := qseter.(*querySet) + _, fi, ind, qs := o.queryRelated(md, name) var relDepth int var limit, offset int64 @@ -377,7 +375,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s } // get QuerySeter for related models to md model -func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { +func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, *querySet) { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) From 3acda41bc7be4494c9d925b06ab954683bbc01a1 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 6 Sep 2020 15:21:07 +0800 Subject: [PATCH 29/35] Fix UT --- pkg/adapter/app.go | 3 +- pkg/adapter/cache/cache_test.go | 191 ------------------- pkg/adapter/flash.go | 2 +- pkg/adapter/metric/prometheus_test.go | 2 +- pkg/adapter/plugins/cors/cors_test.go | 253 -------------------------- pkg/adapter/router.go | 3 + pkg/adapter/utils/file_test.go | 75 -------- pkg/server/web/app.go | 8 +- pkg/server/web/filter.go | 11 +- pkg/server/web/namespace.go | 2 +- pkg/server/web/router_test.go | 28 +-- pkg/task/task_test.go | 14 +- 12 files changed, 41 insertions(+), 551 deletions(-) delete mode 100644 pkg/adapter/cache/cache_test.go delete mode 100644 pkg/adapter/plugins/cors/cors_test.go delete mode 100644 pkg/adapter/utils/file_test.go diff --git a/pkg/adapter/app.go b/pkg/adapter/app.go index 64280a7b..c1046c79 100644 --- a/pkg/adapter/app.go +++ b/pkg/adapter/app.go @@ -255,7 +255,8 @@ func Handler(rootpath string, h http.Handler, options ...interface{}) *App { // beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { + opts := oldToNewFilterOpts(params) return (*App)(web.InsertFilter(pattern, pos, func(ctx *context.Context) { filter((*context2.Context)(ctx)) - }, params...)) + }, opts...)) } diff --git a/pkg/adapter/cache/cache_test.go b/pkg/adapter/cache/cache_test.go deleted file mode 100644 index 470c0a43..00000000 --- a/pkg/adapter/cache/cache_test.go +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "os" - "sync" - "testing" - "time" -) - -func TestCacheIncr(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - //timeoutDuration := 10 * time.Second - - bm.Put("edwardhey", 0, time.Second*20) - wg := sync.WaitGroup{} - wg.Add(10) - for i := 0; i < 10; i++ { - go func() { - defer wg.Done() - bm.Incr("edwardhey") - }() - } - wg.Wait() - if bm.Get("edwardhey").(int) != 10 { - t.Error("Incr err") - } -} - -func TestCache(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - - time.Sleep(30 * time.Second) - - if bm.IsExist("astaxie") { - t.Error("check err") - } - - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test GetMulti - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } - - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } -} - -func TestFileCache(t *testing.T) { - bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } - - //test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } - - os.RemoveAll("cache") -} diff --git a/pkg/adapter/flash.go b/pkg/adapter/flash.go index e5e1c187..02e75ed6 100644 --- a/pkg/adapter/flash.go +++ b/pkg/adapter/flash.go @@ -28,7 +28,7 @@ func NewFlash() *FlashData { // Set message to flash func (fd *FlashData) Set(key string, msg string, args ...interface{}) { - (*web.FlashData)(fd).Set(key, msg, args) + (*web.FlashData)(fd).Set(key, msg, args...) } // Success writes success message to flash. diff --git a/pkg/adapter/metric/prometheus_test.go b/pkg/adapter/metric/prometheus_test.go index d82a6dec..87286e02 100644 --- a/pkg/adapter/metric/prometheus_test.go +++ b/pkg/adapter/metric/prometheus_test.go @@ -22,7 +22,7 @@ import ( "github.com/prometheus/client_golang/prometheus" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg/adapter/context" ) func TestPrometheusMiddleWare(t *testing.T) { diff --git a/pkg/adapter/plugins/cors/cors_test.go b/pkg/adapter/plugins/cors/cors_test.go deleted file mode 100644 index 34039143..00000000 --- a/pkg/adapter/plugins/cors/cors_test.go +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cors - -import ( - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" -) - -// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header -type HTTPHeaderGuardRecorder struct { - *httptest.ResponseRecorder - savedHeaderMap http.Header -} - -// NewRecorder return HttpHeaderGuardRecorder -func NewRecorder() *HTTPHeaderGuardRecorder { - return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil} -} - -func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) { - gr.ResponseRecorder.WriteHeader(code) - gr.savedHeaderMap = gr.ResponseRecorder.Header() -} - -func (gr *HTTPHeaderGuardRecorder) Header() http.Header { - if gr.savedHeaderMap != nil { - // headers were written. clone so we don't get updates - clone := make(http.Header) - for k, v := range gr.savedHeaderMap { - clone[k] = v - } - return clone - } - return gr.ResponseRecorder.Header() -} - -func Test_AllowAll(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - if recorder.HeaderMap.Get(headerAllowOrigin) != "*" { - t.Errorf("Allow-Origin header should be *") - } -} - -func Test_AllowRegexMatch(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - origin := "https://bar.foo.com" - r, _ := http.NewRequest("PUT", "/foo", nil) - r.Header.Add("Origin", origin) - handler.ServeHTTP(recorder, r) - - headerValue := recorder.HeaderMap.Get(headerAllowOrigin) - if headerValue != origin { - t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue) - } -} - -func Test_AllowRegexNoMatch(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowOrigins: []string{"https://*.foo.com"}, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - origin := "https://ww.foo.com.evil.com" - r, _ := http.NewRequest("PUT", "/foo", nil) - r.Header.Add("Origin", origin) - handler.ServeHTTP(recorder, r) - - headerValue := recorder.HeaderMap.Get(headerAllowOrigin) - if headerValue != "" { - t.Errorf("Allow-Origin header should not exist, found %v", headerValue) - } -} - -func Test_OtherHeaders(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowCredentials: true, - AllowMethods: []string{"PATCH", "GET"}, - AllowHeaders: []string{"Origin", "X-whatever"}, - ExposeHeaders: []string{"Content-Length", "Hello"}, - MaxAge: 5 * time.Minute, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials) - methodsVal := recorder.HeaderMap.Get(headerAllowMethods) - headersVal := recorder.HeaderMap.Get(headerAllowHeaders) - exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders) - maxAgeVal := recorder.HeaderMap.Get(headerMaxAge) - - if credentialsVal != "true" { - t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal) - } - - if methodsVal != "PATCH,GET" { - t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal) - } - - if headersVal != "Origin,X-whatever" { - t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal) - } - - if exposedHeadersVal != "Content-Length,Hello" { - t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal) - } - - if maxAgeVal != "300" { - t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal) - } -} - -func Test_DefaultAllowHeaders(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - headersVal := recorder.HeaderMap.Get(headerAllowHeaders) - if headersVal != "Origin,Accept,Content-Type,Authorization" { - t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal) - } -} - -func Test_Preflight(t *testing.T) { - recorder := NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowMethods: []string{"PUT", "PATCH"}, - AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, - })) - - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(200) - }) - - r, _ := http.NewRequest("OPTIONS", "/foo", nil) - r.Header.Add(headerRequestMethod, "PUT") - r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive") - handler.ServeHTTP(recorder, r) - - headers := recorder.Header() - methodsVal := headers.Get(headerAllowMethods) - headersVal := headers.Get(headerAllowHeaders) - originVal := headers.Get(headerAllowOrigin) - - if methodsVal != "PUT,PATCH" { - t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal) - } - - if !strings.Contains(headersVal, "X-whatever") { - t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal) - } - - if !strings.Contains(headersVal, "x-casesensitive") { - t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal) - } - - if originVal != "*" { - t.Errorf("Allow-Origin is expected to be *, found %v", originVal) - } - - if recorder.Code != http.StatusOK { - t.Errorf("Status code is expected to be 200, found %d", recorder.Code) - } -} - -func Benchmark_WithoutCORS(b *testing.B) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - b.ResetTimer() - r, _ := http.NewRequest("PUT", "/foo", nil) - for i := 0; i < b.N; i++ { - handler.ServeHTTP(recorder, r) - } -} - -func Benchmark_WithCORS(b *testing.B) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowCredentials: true, - AllowMethods: []string{"PATCH", "GET"}, - AllowHeaders: []string{"Origin", "X-whatever"}, - MaxAge: 5 * time.Minute, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - b.ResetTimer() - r, _ := http.NewRequest("PUT", "/foo", nil) - for i := 0; i < b.N; i++ { - handler.ServeHTTP(recorder, r) - } -} diff --git a/pkg/adapter/router.go b/pkg/adapter/router.go index 5a36fbee..8e8d9fdb 100644 --- a/pkg/adapter/router.go +++ b/pkg/adapter/router.go @@ -249,6 +249,9 @@ func oldToNewFilterOpts(params []bool) []web.FilterOpt { opts := make([]web.FilterOpt, 0, 4) if len(params) > 0 { opts = append(opts, web.WithReturnOnOutput(params[0])) + } else { + // the default value should be true + opts = append(opts, web.WithReturnOnOutput(true)) } if len(params) > 1 { opts = append(opts, web.WithResetParams(params[1])) diff --git a/pkg/adapter/utils/file_test.go b/pkg/adapter/utils/file_test.go deleted file mode 100644 index b2644157..00000000 --- a/pkg/adapter/utils/file_test.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "path/filepath" - "reflect" - "testing" -) - -var noExistedFile = "/tmp/not_existed_file" - -func TestSelfPath(t *testing.T) { - path := SelfPath() - if path == "" { - t.Error("path cannot be empty") - } - t.Logf("SelfPath: %s", path) -} - -func TestSelfDir(t *testing.T) { - dir := SelfDir() - t.Logf("SelfDir: %s", dir) -} - -func TestFileExists(t *testing.T) { - if !FileExists("./file.go") { - t.Errorf("./file.go should exists, but it didn't") - } - - if FileExists(noExistedFile) { - t.Errorf("Weird, how could this file exists: %s", noExistedFile) - } -} - -func TestSearchFile(t *testing.T) { - path, err := SearchFile(filepath.Base(SelfPath()), SelfDir()) - if err != nil { - t.Error(err) - } - t.Log(path) - - _, err = SearchFile(noExistedFile, ".") - if err == nil { - t.Errorf("err shouldnt be nil, got path: %s", SelfDir()) - } -} - -func TestGrepFile(t *testing.T) { - _, err := GrepFile("", noExistedFile) - if err == nil { - t.Error("expect file-not-existed error, but got nothing") - } - - path := filepath.Join(".", "testdata", "grepe.test") - lines, err := GrepFile(`^\s*[^#]+`, path) - if err != nil { - t.Error(err) - } - if !reflect.DeepEqual(lines, []string{"hello", "world"}) { - t.Errorf("expect [hello world], but receive %v", lines) - } -} diff --git a/pkg/server/web/app.go b/pkg/server/web/app.go index ad3ff663..7511c7fe 100644 --- a/pkg/server/web/app.go +++ b/pkg/server/web/app.go @@ -492,15 +492,15 @@ func Handler(rootpath string, h http.Handler, options ...interface{}) *App { // The pos means action constant including // beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) -func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { - BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) +func InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) *App { + BeeApp.Handlers.InsertFilter(pattern, pos, filter, opts...) return BeeApp } // InsertFilterChain adds a FilterFunc built by filterChain. // This filter will be executed before all filters. // the filter's behavior is like stack -func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App { - BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...) +func InsertFilterChain(pattern string, filterChain FilterChain, opts ...FilterOpt) *App { + BeeApp.Handlers.InsertFilterChain(pattern, filterChain, opts...) return BeeApp } diff --git a/pkg/server/web/filter.go b/pkg/server/web/filter.go index e10faafc..9aab48d6 100644 --- a/pkg/server/web/filter.go +++ b/pkg/server/web/filter.go @@ -45,13 +45,14 @@ type FilterRouter struct { // 2. determining whether or not params need to be reset. func newFilterRouter(pattern string, filter FilterFunc, opts ...FilterOpt) *FilterRouter { mr := &FilterRouter{ - tree: NewTree(), - pattern: pattern, - filterFunc: filter, - returnOnOutput: true, + tree: NewTree(), + pattern: pattern, + filterFunc: filter, } - fos := &filterOpts{} + fos := &filterOpts{ + returnOnOutput: true, + } for _, o := range opts { o(fos) diff --git a/pkg/server/web/namespace.go b/pkg/server/web/namespace.go index e59f38c5..a792aa60 100644 --- a/pkg/server/web/namespace.go +++ b/pkg/server/web/namespace.go @@ -91,7 +91,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { a = FinishRouter } for _, f := range filter { - n.handlers.InsertFilter("*", a, f) + n.handlers.InsertFilter("*", a, f, WithReturnOnOutput(true)) } return n } diff --git a/pkg/server/web/router_test.go b/pkg/server/web/router_test.go index 14ad1484..33b75703 100644 --- a/pkg/server/web/router_test.go +++ b/pkg/server/web/router_test.go @@ -423,7 +423,7 @@ func TestInsertFilter(t *testing.T) { testName := "TestInsertFilter" mux := NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true)) if !mux.filters[BeforeRouter][0].returnOnOutput { t.Errorf( "%s: passing no variadic params should set returnOnOutput to true", @@ -436,7 +436,7 @@ func TestInsertFilter(t *testing.T) { } mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(false)) if mux.filters[BeforeRouter][0].returnOnOutput { t.Errorf( "%s: passing false as 1st variadic param should set returnOnOutput to false", @@ -444,7 +444,7 @@ func TestInsertFilter(t *testing.T) { } mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true), WithResetParams(true)) if !mux.filters[BeforeRouter][0].resetParams { t.Errorf( "%s: passing true as 2nd variadic param should set resetParams to true", @@ -461,7 +461,7 @@ func TestParamResetFilter(t *testing.T) { mux := NewControllerRegister() - mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true) + mux.InsertFilter("*", BeforeExec, beegoResetParams, WithReturnOnOutput(true), WithResetParams(true)) mux.Get(route, beegoHandleResetParams) @@ -514,8 +514,8 @@ func TestFilterBeforeExec(t *testing.T) { url := "/beforeExec" mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoBeforeExec1) + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, BeforeExec, beegoBeforeExec1, WithReturnOnOutput(true)) mux.Get(url, beegoFilterFunc) @@ -542,7 +542,7 @@ func TestFilterAfterExec(t *testing.T) { mux := NewControllerRegister() mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoAfterExec1, false) + mux.InsertFilter(url, AfterExec, beegoAfterExec1, WithReturnOnOutput(false)) mux.Get(url, beegoFilterFunc) @@ -570,10 +570,10 @@ func TestFilterFinishRouter(t *testing.T) { url := "/finishRouter" mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoFilterNoOutput) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1) + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, AfterExec, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(true)) mux.Get(url, beegoFilterFunc) @@ -604,7 +604,7 @@ func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { url := "/finishRouterMultiFirstOnly" mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false)) mux.InsertFilter(url, FinishRouter, beegoFinishRouter2) mux.Get(url, beegoFilterFunc) @@ -631,8 +631,8 @@ func TestFilterFinishRouterMulti(t *testing.T) { url := "/finishRouterMulti" mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false)) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, WithReturnOnOutput(false)) mux.Get(url, beegoFilterFunc) diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go index 9f73ce46..488729dc 100644 --- a/pkg/task/task_test.go +++ b/pkg/task/task_test.go @@ -15,6 +15,7 @@ package task import ( + "context" "errors" "fmt" "sync" @@ -25,7 +26,10 @@ import ( ) func TestParse(t *testing.T) { - tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) + tk := NewTask("taska", "0/30 * * * * *", func(ctx context.Context) error { + fmt.Println("hello world") + return nil + }) err := tk.Run(nil) if err != nil { t.Fatal(err) @@ -39,9 +43,9 @@ func TestParse(t *testing.T) { func TestSpec(t *testing.T) { wg := &sync.WaitGroup{} wg.Add(2) - tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) - tk2 := NewTask("tk2", "0,10,20 * * * * *", func() error { fmt.Println("tk2"); wg.Done(); return nil }) - tk3 := NewTask("tk3", "0 10 * * * *", func() error { fmt.Println("tk3"); wg.Done(); return nil }) + tk1 := NewTask("tk1", "0 12 * * * *", func(ctx context.Context) error { fmt.Println("tk1"); return nil }) + tk2 := NewTask("tk2", "0,10,20 * * * * *", func(ctx context.Context) error { fmt.Println("tk2"); wg.Done(); return nil }) + tk3 := NewTask("tk3", "0 10 * * * *", func(ctx context.Context) error { fmt.Println("tk3"); wg.Done(); return nil }) AddTask("tk1", tk1) AddTask("tk2", tk2) @@ -58,7 +62,7 @@ func TestSpec(t *testing.T) { func TestTask_Run(t *testing.T) { cnt := -1 - task := func() error { + task := func(ctx context.Context) error { cnt++ fmt.Printf("Hello, world! %d \n", cnt) return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) From 6bf01eaeca8b0e8ef4cb9c35e5159a8ae55e9401 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 7 Sep 2020 20:36:54 +0800 Subject: [PATCH 30/35] Move pr 3784 here --- pkg/client/orm/orm_raw.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pkg/client/orm/orm_raw.go b/pkg/client/orm/orm_raw.go index c2539147..e11e97fa 100644 --- a/pkg/client/orm/orm_raw.go +++ b/pkg/client/orm/orm_raw.go @@ -330,6 +330,8 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { return err } + structTagMap := make(map[reflect.StructTag]map[string]string) + defer rows.Close() if rows.Next() { @@ -396,7 +398,12 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { recursiveSetField(f) } - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + // thanks @Gazeboxu. + tags := structTagMap[fe.Tag] + if tags == nil { + _, tags = parseStructTag(fe.Tag.Get(defaultStructTagName)) + structTagMap[fe.Tag] = tags + } var col string if col = tags["column"]; col == "" { col = nameStrategyMap[nameStrategy](fe.Name) From 0f50b07a20b25be899e573dbb6b50d464f9001d1 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 7 Sep 2020 21:40:20 +0800 Subject: [PATCH 31/35] allow users to ignore some table when run orm commands --- pkg/client/orm/cmd.go | 6 +++++ pkg/client/orm/models.go | 2 +- pkg/client/orm/models_utils.go | 12 ++++++++++ pkg/client/orm/models_utils_test.go | 35 +++++++++++++++++++++++++++++ pkg/client/orm/types.go | 5 +++++ 5 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 pkg/client/orm/models_utils_test.go diff --git a/pkg/client/orm/cmd.go b/pkg/client/orm/cmd.go index e03fc0ee..b0661971 100644 --- a/pkg/client/orm/cmd.go +++ b/pkg/client/orm/cmd.go @@ -142,6 +142,12 @@ func (d *commandSyncDb) Run() error { } for i, mi := range modelCache.allOrdered() { + + if !isApplicableTableForDB(mi.addrField, d.al.Name) { + fmt.Printf("table `%s` is not applicable to database '%s'\n", mi.table, d.al.Name) + continue + } + if tables[mi.table] { if !d.noInfo { fmt.Printf("table `%s` already exists, skip\n", mi.table) diff --git a/pkg/client/orm/models.go b/pkg/client/orm/models.go index a7de10f7..24f564ab 100644 --- a/pkg/client/orm/models.go +++ b/pkg/client/orm/models.go @@ -414,7 +414,7 @@ func (mc *_modelCache) getDbDropSQL(al *alias) (queries []string, err error) { for _, mi := range modelCache.allOrdered() { queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) } - return queries,nil + return queries, nil } //getDbCreateSQL get database scheme creation sql queries diff --git a/pkg/client/orm/models_utils.go b/pkg/client/orm/models_utils.go index 6fca59a9..950ca243 100644 --- a/pkg/client/orm/models_utils.go +++ b/pkg/client/orm/models_utils.go @@ -107,6 +107,18 @@ func getTableUnique(val reflect.Value) [][]string { return nil } +// get whether the table needs to be created for the database alias +func isApplicableTableForDB(val reflect.Value, db string) bool { + fun := val.MethodByName("IsApplicableTableForDB") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{reflect.ValueOf(db)}) + if len(vals) > 0 && vals[0].Kind() == reflect.Bool { + return vals[0].Bool() + } + } + return true +} + // get snaked column name func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { column := col diff --git a/pkg/client/orm/models_utils_test.go b/pkg/client/orm/models_utils_test.go new file mode 100644 index 00000000..0a6995b3 --- /dev/null +++ b/pkg/client/orm/models_utils_test.go @@ -0,0 +1,35 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type NotApplicableModel struct { + Id int +} + +func (n *NotApplicableModel) IsApplicableTableForDB(db string) bool { + return db == "default" +} + +func Test_IsApplicableTableForDB(t *testing.T) { + assert.False(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "defa")) + assert.True(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "default")) +} diff --git a/pkg/client/orm/types.go b/pkg/client/orm/types.go index eb34e759..b0c793b7 100644 --- a/pkg/client/orm/types.go +++ b/pkg/client/orm/types.go @@ -75,6 +75,11 @@ type TableUniqueI interface { TableUnique() [][]string } +// IsApplicableTableForDB if return false, we won't create table to this db +type IsApplicableTableForDB interface { + IsApplicableTableForDB(db string) bool +} + // Driver define database driver type Driver interface { Name() string From f580a714d5748d86d2c2ad6915030253162c2aa5 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 8 Sep 2020 20:47:39 +0800 Subject: [PATCH 32/35] Optimize orm by using BDOption rather than hints --- pkg/adapter/orm/db_alias.go | 8 ++- pkg/client/orm/db_alias.go | 77 +++++++++++++++++---------- pkg/client/orm/db_alias_test.go | 16 +++--- pkg/client/orm/hints/db_hints.go | 30 +---------- pkg/client/orm/hints/db_hints_test.go | 28 ---------- pkg/client/orm/models_test.go | 4 +- pkg/client/orm/orm.go | 2 +- 7 files changed, 63 insertions(+), 102 deletions(-) diff --git a/pkg/adapter/orm/db_alias.go b/pkg/adapter/orm/db_alias.go index 2ecc80e5..b1f1a724 100644 --- a/pkg/adapter/orm/db_alias.go +++ b/pkg/adapter/orm/db_alias.go @@ -20,8 +20,6 @@ import ( "time" "github.com/astaxie/beego/pkg/client/orm" - "github.com/astaxie/beego/pkg/client/orm/hints" - "github.com/astaxie/beego/pkg/infrastructure/utils" ) // DriverType database driver constant int. @@ -86,13 +84,13 @@ func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { - opts := make([]utils.KV, 0, 2) + opts := make([]orm.DBOption, 0, 2) if len(params) > 0 { - opts = append(opts, hints.MaxIdleConnections(params[0])) + opts = append(opts, orm.MaxIdleConnections(params[0])) } if len(params) > 1 { - opts = append(opts, hints.MaxOpenConnections(params[1])) + opts = append(opts, orm.MaxOpenConnections(params[1])) } return orm.RegisterDataBase(aliasName, driverName, dataSource, opts...) } diff --git a/pkg/client/orm/db_alias.go b/pkg/client/orm/db_alias.go index c72f29c4..29e0904c 100644 --- a/pkg/client/orm/db_alias.go +++ b/pkg/client/orm/db_alias.go @@ -21,9 +21,6 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/client/orm/hints" - "github.com/astaxie/beego/pkg/infrastructure/utils" - lru "github.com/hashicorp/golang-lru" ) @@ -278,6 +275,7 @@ type alias struct { MaxIdleConns int MaxOpenConns int ConnMaxLifetime time.Duration + StmtCacheSize int DB *DB DbBaser dbBaser TZ *time.Location @@ -340,7 +338,7 @@ func detectTZ(al *alias) { } } -func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...utils.KV) (*alias, error) { +func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) { existErr := fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) if _, ok := dataBaseCache.get(aliasName); ok { return nil, existErr @@ -358,32 +356,35 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...utils.KV) return al, nil } -func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...utils.KV) (*alias, error) { - kvs := utils.NewKVs(params...) +func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) { + + al := &alias{} + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + } + + for _, p := range params { + p(al) + } var stmtCache *lru.Cache var stmtCacheSize int - maxStmtCacheSize := kvs.GetValueOr(hints.KeyMaxStmtCacheSize, 0).(int) - if maxStmtCacheSize > 0 { - _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) + if al.StmtCacheSize > 0 { + _stmtCache, errC := newStmtDecoratorLruWithEvict(al.StmtCacheSize) if errC != nil { return nil, errC } else { stmtCache = _stmtCache - stmtCacheSize = maxStmtCacheSize + stmtCacheSize = al.StmtCacheSize } } - al := new(alias) al.Name = aliasName al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: stmtCache, - stmtDecoratorsLimit: stmtCacheSize, - } + al.DB.stmtDecorators = stmtCache + al.DB.stmtDecoratorsLimit = stmtCacheSize if dr, ok := drivers[driverName]; ok { al.DbBaser = dbBasers[dr] @@ -399,14 +400,6 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...utils.KV detectTZ(al) - kvs.IfContains(hints.KeyMaxIdleConnections, func(value interface{}) { - al.SetMaxIdleConns(value.(int)) - }).IfContains(hints.KeyMaxOpenConnections, func(value interface{}) { - al.SetMaxOpenConns(value.(int)) - }).IfContains(hints.KeyConnMaxLifetime, func(value interface{}) { - al.SetConnMaxLifetime(value.(time.Duration)) - }) - return al, nil } @@ -442,13 +435,13 @@ func (al *alias) SetConnMaxLifetime(lifeTime time.Duration) { } // AddAliasWthDB add a aliasName for the drivename -func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...utils.KV) error { +func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) error { _, err := addAliasWthDB(aliasName, driverName, db, params...) return err } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, params ...utils.KV) error { +func RegisterDataBase(aliasName, driverName, dataSource string, params ...DBOption) error { var ( err error db *sql.DB @@ -561,3 +554,33 @@ func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) { } return cache, nil } + +type DBOption func(al *alias) + +// MaxIdleConnections return a hint about MaxIdleConnections +func MaxIdleConnections(maxIdleConn int) DBOption { + return func(al *alias) { + al.SetMaxIdleConns(maxIdleConn) + } +} + +// MaxOpenConnections return a hint about MaxOpenConnections +func MaxOpenConnections(maxOpenConn int) DBOption { + return func(al *alias) { + al.SetMaxOpenConns(maxOpenConn) + } +} + +// ConnMaxLifetime return a hint about ConnMaxLifetime +func ConnMaxLifetime(v time.Duration) DBOption { + return func(al *alias) { + al.SetConnMaxLifetime(v) + } +} + +// MaxStmtCacheSize return a hint about MaxStmtCacheSize +func MaxStmtCacheSize(v int) DBOption { + return func(al *alias) { + al.StmtCacheSize = v + } +} diff --git a/pkg/client/orm/db_alias_test.go b/pkg/client/orm/db_alias_test.go index 0043ba76..6275cb2a 100644 --- a/pkg/client/orm/db_alias_test.go +++ b/pkg/client/orm/db_alias_test.go @@ -18,16 +18,14 @@ import ( "testing" "time" - "github.com/astaxie/beego/pkg/client/orm/hints" - "github.com/stretchr/testify/assert" ) func TestRegisterDataBase(t *testing.T) { err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, - hints.MaxIdleConnections(20), - hints.MaxOpenConnections(300), - hints.ConnMaxLifetime(time.Minute)) + MaxIdleConnections(20), + MaxOpenConnections(300), + ConnMaxLifetime(time.Minute)) assert.Nil(t, err) al := getDbAlias("test-params") @@ -39,7 +37,7 @@ func TestRegisterDataBase(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(-1)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -49,7 +47,7 @@ func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(0)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -59,7 +57,7 @@ func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(1)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -69,7 +67,7 @@ func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(841)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) assert.Nil(t, err) al := getDbAlias(aliasName) diff --git a/pkg/client/orm/hints/db_hints.go b/pkg/client/orm/hints/db_hints.go index 4d199312..7340bd07 100644 --- a/pkg/client/orm/hints/db_hints.go +++ b/pkg/client/orm/hints/db_hints.go @@ -15,20 +15,12 @@ package hints import ( - "time" - "github.com/astaxie/beego/pkg/infrastructure/utils" ) const ( - //db level - KeyMaxIdleConnections = iota - KeyMaxOpenConnections - KeyConnMaxLifetime - KeyMaxStmtCacheSize - //query level - KeyForceIndex + KeyForceIndex = iota KeyUseIndex KeyIgnoreIndex KeyForUpdate @@ -57,26 +49,6 @@ func (s *Hint) GetValue() interface{} { var _ utils.KV = new(Hint) -// MaxIdleConnections return a hint about MaxIdleConnections -func MaxIdleConnections(v int) *Hint { - return NewHint(KeyMaxIdleConnections, v) -} - -// MaxOpenConnections return a hint about MaxOpenConnections -func MaxOpenConnections(v int) *Hint { - return NewHint(KeyMaxOpenConnections, v) -} - -// ConnMaxLifetime return a hint about ConnMaxLifetime -func ConnMaxLifetime(v time.Duration) *Hint { - return NewHint(KeyConnMaxLifetime, v) -} - -// MaxStmtCacheSize return a hint about MaxStmtCacheSize -func MaxStmtCacheSize(v int) *Hint { - return NewHint(KeyMaxStmtCacheSize, v) -} - // ForceIndex return a hint about ForceIndex func ForceIndex(indexes ...string) *Hint { return NewHint(KeyForceIndex, indexes) diff --git a/pkg/client/orm/hints/db_hints_test.go b/pkg/client/orm/hints/db_hints_test.go index 4e962a8f..510f9f16 100644 --- a/pkg/client/orm/hints/db_hints_test.go +++ b/pkg/client/orm/hints/db_hints_test.go @@ -48,34 +48,6 @@ func TestNewHint_float(t *testing.T) { assert.Equal(t, hint.GetValue(), value) } -func TestMaxOpenConnections(t *testing.T) { - i := 887423 - hint := MaxOpenConnections(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyMaxOpenConnections) -} - -func TestConnMaxLifetime(t *testing.T) { - i := time.Hour - hint := ConnMaxLifetime(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyConnMaxLifetime) -} - -func TestMaxIdleConnections(t *testing.T) { - i := 42316 - hint := MaxIdleConnections(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyMaxIdleConnections) -} - -func TestMaxStmtCacheSize(t *testing.T) { - i := 94157 - hint := MaxStmtCacheSize(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyMaxStmtCacheSize) -} - func TestForceIndex(t *testing.T) { s := []string{`f_index1`, `f_index2`, `f_index3`} hint := ForceIndex(s...) diff --git a/pkg/client/orm/models_test.go b/pkg/client/orm/models_test.go index 81ba30df..f0044f6d 100644 --- a/pkg/client/orm/models_test.go +++ b/pkg/client/orm/models_test.go @@ -22,8 +22,6 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/client/orm/hints" - _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -529,7 +527,7 @@ func init() { os.Exit(2) } - err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, hints.MaxIdleConnections(20)) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, MaxIdleConnections(20)) if err != nil { panic(fmt.Sprintf("can not register database: %v", err)) diff --git a/pkg/client/orm/orm.go b/pkg/client/orm/orm.go index 95bbcb31..bfb710d1 100644 --- a/pkg/client/orm/orm.go +++ b/pkg/client/orm/orm.go @@ -601,7 +601,7 @@ func NewOrmUsingDB(aliasName string) Ormer { } // NewOrmWithDB create a new ormer object with specify *sql.DB for query -func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...utils.KV) (Ormer, error) { +func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...DBOption) (Ormer, error) { al, err := newAliasWithDb(aliasName, driverName, db, params...) if err != nil { return nil, err From 8982f5d70236f6083740c3de66ae2a58607eb260 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Wed, 9 Sep 2020 00:23:57 +0100 Subject: [PATCH 33/35] Add unit tests for custom log formatter Also moved is Colorful check to WriteMsg function to make the interface for user's using the custom logging formatting simpler. The user does not have to check if the text is colorful now, the WriteMsg function handles it. --- pkg/logs/console.go | 10 ++---- .../logformattertest/log_formatter_test.go | 36 +++++++++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) create mode 100644 pkg/logs/logformattertest/log_formatter_test.go diff --git a/pkg/logs/console.go b/pkg/logs/console.go index 34114e4a..a3e5fb5a 100644 --- a/pkg/logs/console.go +++ b/pkg/logs/console.go @@ -58,10 +58,6 @@ type consoleWriter struct { func (c *consoleWriter) Format(lm *LogMsg) string { msg := lm.Msg - if c.Colorful { - msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) - } - h, _, _ := formatTimeHeader(lm.When) bytes := append(append(h, msg...), '\n') @@ -105,13 +101,13 @@ func (c *consoleWriter) WriteMsg(lm *LogMsg) error { if lm.Level > c.Level { return nil } - // fmt.Printf("Formatted: %s\n\n", c.fmtter.Format(lm)) + + msg := "" + if c.Colorful { lm.Msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) } - msg := "" - if c.customFormatter != nil { msg = c.customFormatter(lm) } else { diff --git a/pkg/logs/logformattertest/log_formatter_test.go b/pkg/logs/logformattertest/log_formatter_test.go new file mode 100644 index 00000000..2d99a8e6 --- /dev/null +++ b/pkg/logs/logformattertest/log_formatter_test.go @@ -0,0 +1,36 @@ +package logformattertest + +import ( + "fmt" + "testing" + + "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/logs" +) + +func customFormatter(lm *logs.LogMsg) string { + return fmt.Sprintf("[CUSTOM CONSOLE LOGGING] %s", lm.Msg) +} + +func globalFormatter(lm *logs.LogMsg) string { + return fmt.Sprintf("[GLOBAL] %s", lm.Msg) +} + +func TestCustomLoggingFormatter(t *testing.T) { + // beego.BConfig.Log.AccessLogs = true + + logs.SetLoggerWithOpts("console", []string{`{"color":true}`}, common.SimpleKV{Key: "formatter", Value: customFormatter}) + + // Message will be formatted by the customFormatter with colorful text set to true + logs.Informational("Test message") +} + +func TestGlobalLoggingFormatter(t *testing.T) { + logs.SetGlobalFormatter(globalFormatter) + + logs.SetLogger("console", `{"color":true}`) + + // Message will be formatted by globalFormatter + logs.Informational("Test message") + +} From 63cd8e4e15de50618bf0da81e59c0789864e4975 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Fri, 11 Sep 2020 21:10:12 +0800 Subject: [PATCH 34/35] refactor log module --- pkg/infrastructure/logs/alils/alils.go | 56 ++--- pkg/infrastructure/logs/conn.go | 52 ++--- pkg/infrastructure/logs/console.go | 63 +++-- pkg/infrastructure/logs/es/es.go | 65 +++--- pkg/infrastructure/logs/file.go | 65 +++--- pkg/infrastructure/logs/file_test.go | 5 + pkg/infrastructure/logs/formatter.go | 34 +++ pkg/infrastructure/logs/jianliao.go | 58 ++--- pkg/infrastructure/logs/log.go | 215 ++++++------------ pkg/infrastructure/logs/log_formatter_test.go | 35 --- pkg/infrastructure/logs/multifile.go | 43 ++-- pkg/infrastructure/logs/slack.go | 43 ++-- pkg/infrastructure/logs/smtp.go | 34 +-- 13 files changed, 346 insertions(+), 422 deletions(-) create mode 100644 pkg/infrastructure/logs/formatter.go delete mode 100644 pkg/infrastructure/logs/log_formatter_test.go diff --git a/pkg/infrastructure/logs/alils/alils.go b/pkg/infrastructure/logs/alils/alils.go index 03e97045..0689aae0 100644 --- a/pkg/infrastructure/logs/alils/alils.go +++ b/pkg/infrastructure/logs/alils/alils.go @@ -2,12 +2,14 @@ package alils import ( "encoding/json" + "fmt" "strings" "sync" - "github.com/astaxie/beego/pkg/infrastructure/logs" - "github.com/astaxie/beego/pkg/infrastructure/utils" "github.com/gogo/protobuf/proto" + "github.com/pkg/errors" + + "github.com/astaxie/beego/pkg/infrastructure/logs" ) const ( @@ -28,40 +30,35 @@ type Config struct { Source string `json:"source"` Level int `json:"level"` FlushWhen int `json:"flush_when"` + Formatter string `json:"formatter"` } // aliLSWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type aliLSWriter struct { - store *LogStore - group []*LogGroup - withMap bool - groupMap map[string]*LogGroup - lock *sync.Mutex - customFormatter func(*logs.LogMsg) string + store *LogStore + group []*LogGroup + withMap bool + groupMap map[string]*LogGroup + lock *sync.Mutex Config + formatter logs.LogFormatter } // NewAliLS creates a new Logger func NewAliLS() logs.Logger { alils := new(aliLSWriter) alils.Level = logs.LevelTrace + alils.formatter = alils return alils } // Init parses config and initializes struct -func (c *aliLSWriter) Init(jsonConfig string, opts ...utils.KV) error { - - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := logs.GetFormatter(elem) - if err != nil { - return err - } - c.customFormatter = formatter - } +func (c *aliLSWriter) Init(config string) error { + err := json.Unmarshal([]byte(config), c) + if err != nil { + return err } - json.Unmarshal([]byte(jsonConfig), c) if c.FlushWhen > CacheSize { c.FlushWhen = CacheSize @@ -110,11 +107,23 @@ func (c *aliLSWriter) Init(jsonConfig string, opts ...utils.KV) error { c.lock = &sync.Mutex{} + if len(c.Formatter) > 0 { + fmtr, ok := logs.GetFormatter(c.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", c.Formatter)) + } + c.formatter = fmtr + } + return nil } func (c *aliLSWriter) Format(lm *logs.LogMsg) string { - return lm.Msg + return lm.OldStyleFormat() +} + +func (c *aliLSWriter) SetFormatter(f logs.LogFormatter) { + c.formatter = f } // WriteMsg writes a message in connection. @@ -145,11 +154,7 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { lg = c.group[0] } - if c.customFormatter != nil { - content = c.customFormatter(lm) - } else { - content = c.Format(lm) - } + content = c.formatter.Format(lm) c1 := &LogContent{ Key: proto.String("msg"), @@ -170,7 +175,6 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { if len(lg.Logs) >= c.FlushWhen { c.flush(lg) } - return nil } diff --git a/pkg/infrastructure/logs/conn.go b/pkg/infrastructure/logs/conn.go index f7d44d7f..1fd71be7 100644 --- a/pkg/infrastructure/logs/conn.go +++ b/pkg/infrastructure/logs/conn.go @@ -16,51 +16,55 @@ package logs import ( "encoding/json" + "fmt" "io" "net" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/pkg/errors" ) // connWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type connWriter struct { - lg *logWriter - innerWriter io.WriteCloser - customFormatter func(*LogMsg) string - ReconnectOnMsg bool `json:"reconnectOnMsg"` - Reconnect bool `json:"reconnect"` - Net string `json:"net"` - Addr string `json:"addr"` - Level int `json:"level"` + lg *logWriter + innerWriter io.WriteCloser + formatter LogFormatter + Formatter string `json:"formatter"` + ReconnectOnMsg bool `json:"reconnectOnMsg"` + Reconnect bool `json:"reconnect"` + Net string `json:"net"` + Addr string `json:"addr"` + Level int `json:"level"` } // NewConn creates new ConnWrite returning as LoggerInterface. func NewConn() Logger { conn := new(connWriter) conn.Level = LevelTrace + conn.formatter = conn return conn } func (c *connWriter) Format(lm *LogMsg) string { - return lm.Msg + return lm.OldStyleFormat() } // Init initializes a connection writer with json config. // json config only needs they "level" key -func (c *connWriter) Init(jsonConfig string, opts ...utils.KV) error { - - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := GetFormatter(elem) - if err != nil { - return err - } - c.customFormatter = formatter +func (c *connWriter) Init(config string) error { + res := json.Unmarshal([]byte(config), c) + if res == nil && len(c.Formatter) > 0 { + fmtr, ok := GetFormatter(c.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", c.Formatter)) } + c.formatter = fmtr } + return res +} - return json.Unmarshal([]byte(jsonConfig), c) +func (c *connWriter) SetFormatter(f LogFormatter) { + c.formatter = f } // WriteMsg writes message in connection. @@ -80,13 +84,7 @@ func (c *connWriter) WriteMsg(lm *LogMsg) error { defer c.innerWriter.Close() } - msg := "" - if c.customFormatter != nil { - msg = c.customFormatter(lm) - } else { - msg = c.Format(lm) - - } + msg := c.formatter.Format(lm) _, err := c.lg.writeln(msg) if err != nil { diff --git a/pkg/infrastructure/logs/console.go b/pkg/infrastructure/logs/console.go index 802d79f5..f99ef11b 100644 --- a/pkg/infrastructure/logs/console.go +++ b/pkg/infrastructure/logs/console.go @@ -16,11 +16,11 @@ package logs import ( "encoding/json" + "fmt" "os" "strings" - "github.com/astaxie/beego/pkg/infrastructure/utils" - + "github.com/pkg/errors" "github.com/shiena/ansicolor" ) @@ -49,20 +49,25 @@ var colors = []brush{ // consoleWriter implements LoggerInterface and writes messages to terminal. type consoleWriter struct { - lg *logWriter - customFormatter func(*LogMsg) string - Level int `json:"level"` - Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color + lg *logWriter + formatter LogFormatter + Formatter string `json:"formatter"` + Level int `json:"level"` + Colorful bool `json:"color"` // this filed is useful only when system's terminal supports color } func (c *consoleWriter) Format(lm *LogMsg) string { - msg := lm.Msg - + msg := lm.OldStyleFormat() + if c.Colorful { + msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) + } h, _, _ := formatTimeHeader(lm.When) bytes := append(append(h, msg...), '\n') - return string(bytes) +} +func (c *consoleWriter) SetFormatter(f LogFormatter) { + c.formatter = f } // NewConsole creates ConsoleWriter returning as LoggerInterface. @@ -72,28 +77,27 @@ func NewConsole() Logger { Level: LevelDebug, Colorful: true, } + cw.formatter = cw return cw } // Init initianlizes the console logger. // jsonConfig must be in the format '{"level":LevelTrace}' -func (c *consoleWriter) Init(jsonConfig string, opts ...utils.KV) error { +func (c *consoleWriter) Init(config string) error { - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := GetFormatter(elem) - if err != nil { - return err - } - c.customFormatter = formatter - } - } - - if len(jsonConfig) == 0 { + if len(config) == 0 { return nil } - return json.Unmarshal([]byte(jsonConfig), c) + res := json.Unmarshal([]byte(config), c) + if res == nil && len(c.Formatter) > 0 { + fmtr, ok := GetFormatter(c.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", c.Formatter)) + } + c.formatter = fmtr + } + return res } // WriteMsg writes message in console. @@ -101,20 +105,7 @@ func (c *consoleWriter) WriteMsg(lm *LogMsg) error { if lm.Level > c.Level { return nil } - - msg := "" - - if c.Colorful { - lm.Msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) - } - - if c.customFormatter != nil { - msg = c.customFormatter(lm) - } else { - msg = c.Format(lm) - - } - + msg := c.formatter.Format(lm) c.lg.writeln(msg) return nil } diff --git a/pkg/infrastructure/logs/es/es.go b/pkg/infrastructure/logs/es/es.go index 857a1a34..438a6da6 100644 --- a/pkg/infrastructure/logs/es/es.go +++ b/pkg/infrastructure/logs/es/es.go @@ -13,7 +13,6 @@ import ( "github.com/elastic/go-elasticsearch/v6/esapi" "github.com/astaxie/beego/pkg/infrastructure/logs" - "github.com/astaxie/beego/pkg/infrastructure/utils" ) // NewES returns a LoggerInterface @@ -32,29 +31,34 @@ func NewES() logs.Logger { // import _ "github.com/astaxie/beego/logs/es" type esLogger struct { *elasticsearch.Client - DSN string `json:"dsn"` - Level int `json:"level"` - customFormatter func(*logs.LogMsg) string + DSN string `json:"dsn"` + Level int `json:"level"` + formatter logs.LogFormatter + Formatter string `json:"formatter"` } func (el *esLogger) Format(lm *logs.LogMsg) string { - return lm.Msg + + msg := lm.OldStyleFormat() + idx := LogDocument{ + Timestamp: lm.When.Format(time.RFC3339), + Msg: msg, + } + body, err := json.Marshal(idx) + if err != nil { + return msg + } + return string(body) +} + +func (el *esLogger) SetFormatter(f logs.LogFormatter) { + el.formatter = f } // {"dsn":"http://localhost:9200/","level":1} -func (el *esLogger) Init(jsonConfig string, opts ...utils.KV) error { +func (el *esLogger) Init(config string) error { - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := logs.GetFormatter(elem) - if err != nil { - return err - } - el.customFormatter = formatter - } - } - - err := json.Unmarshal([]byte(jsonConfig), el) + err := json.Unmarshal([]byte(config), el) if err != nil { return err } @@ -73,6 +77,13 @@ func (el *esLogger) Init(jsonConfig string, opts ...utils.KV) error { } el.Client = conn } + if len(el.Formatter) > 0 { + fmtr, ok := logs.GetFormatter(el.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", el.Formatter)) + } + el.formatter = fmtr + } return nil } @@ -82,28 +93,14 @@ func (el *esLogger) WriteMsg(lm *logs.LogMsg) error { return nil } - msg := "" - if el.customFormatter != nil { - msg = el.customFormatter(lm) - } else { - msg = el.Format(lm) - } + msg := el.formatter.Format(lm) - idx := LogDocument{ - Timestamp: lm.When.Format(time.RFC3339), - Msg: msg, - } - - body, err := json.Marshal(idx) - if err != nil { - return err - } req := esapi.IndexRequest{ Index: fmt.Sprintf("%04d.%02d.%02d", lm.When.Year(), lm.When.Month(), lm.When.Day()), DocumentType: "logs", - Body: strings.NewReader(string(body)), + Body: strings.NewReader(msg), } - _, err = req.Do(context.Background(), el.Client) + _, err := req.Do(context.Background(), el.Client) return err } diff --git a/pkg/infrastructure/logs/file.go b/pkg/infrastructure/logs/file.go index 0c96918c..b01be357 100644 --- a/pkg/infrastructure/logs/file.go +++ b/pkg/infrastructure/logs/file.go @@ -27,8 +27,6 @@ import ( "strings" "sync" "time" - - "github.com/astaxie/beego/pkg/infrastructure/utils" ) // fileLogWriter implements LoggerInterface. @@ -62,8 +60,6 @@ type fileLogWriter struct { hourlyOpenDate int hourlyOpenTime time.Time - customFormatter func(*LogMsg) string - Rotate bool `json:"rotate"` Level int `json:"level"` @@ -73,6 +69,9 @@ type fileLogWriter struct { RotatePerm string `json:"rotateperm"` fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix + + formatter LogFormatter + Formatter string `json:"formatter"` } // newFileWriter creates a FileLogWriter returning as LoggerInterface. @@ -90,11 +89,19 @@ func newFileWriter() Logger { MaxFiles: 999, MaxSize: 1 << 28, } + w.formatter = w return w } func (w *fileLogWriter) Format(lm *LogMsg) string { - return lm.Msg + msg := lm.OldStyleFormat() + hd, _, _ := formatTimeHeader(lm.When) + msg = fmt.Sprintf("%s %s\n", string(hd), msg) + return msg +} + +func (w *fileLogWriter) SetFormatter(f LogFormatter) { + w.formatter = f } // Init file logger with json config. @@ -108,19 +115,9 @@ func (w *fileLogWriter) Format(lm *LogMsg) string { // "rotate":true, // "perm":"0600" // } -func (w *fileLogWriter) Init(jsonConfig string, opts ...utils.KV) error { +func (w *fileLogWriter) Init(config string) error { - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := GetFormatter(elem) - if err != nil { - return err - } - w.customFormatter = formatter - } - } - - err := json.Unmarshal([]byte(jsonConfig), w) + err := json.Unmarshal([]byte(config), w) if err != nil { return err } @@ -132,6 +129,14 @@ func (w *fileLogWriter) Init(jsonConfig string, opts ...utils.KV) error { if w.suffix == "" { w.suffix = ".log" } + + if len(w.Formatter) > 0 { + fmtr, ok := GetFormatter(w.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", w.Formatter)) + } + w.formatter = fmtr + } err = w.startLogger() return err } @@ -149,13 +154,13 @@ func (w *fileLogWriter) startLogger() error { return w.initFd() } -func (w *fileLogWriter) needRotateDaily(size int, day int) bool { +func (w *fileLogWriter) needRotateDaily(day int) bool { return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || (w.Daily && day != w.dailyOpenDate) } -func (w *fileLogWriter) needRotateHourly(size int, hour int) bool { +func (w *fileLogWriter) needRotateHourly(hour int) bool { return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || (w.Hourly && hour != w.hourlyOpenDate) @@ -167,31 +172,25 @@ func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { if lm.Level > w.Level { return nil } - hd, d, h := formatTimeHeader(lm.When) - msg := "" - if w.customFormatter != nil { - msg = w.customFormatter(lm) - } else { - msg = w.Format(lm) - } + _, d, h := formatTimeHeader(lm.When) - msg = fmt.Sprintf("%s %s\n", string(hd), msg) + msg := w.formatter.Format(lm) if w.Rotate { w.RLock() - if w.needRotateHourly(len(lm.Msg), h) { + if w.needRotateHourly(h) { w.RUnlock() w.Lock() - if w.needRotateHourly(len(lm.Msg), h) { + if w.needRotateHourly(h) { if err := w.doRotate(lm.When); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } } w.Unlock() - } else if w.needRotateDaily(len(lm.Msg), d) { + } else if w.needRotateDaily(d) { w.RUnlock() w.Lock() - if w.needRotateDaily(len(lm.Msg), d) { + if w.needRotateDaily(d) { if err := w.doRotate(lm.When); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } @@ -263,7 +262,7 @@ func (w *fileLogWriter) dailyRotate(openTime time.Time) { tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) <-tm.C w.Lock() - if w.needRotateDaily(0, time.Now().Day()) { + if w.needRotateDaily(time.Now().Day()) { if err := w.doRotate(time.Now()); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } @@ -278,7 +277,7 @@ func (w *fileLogWriter) hourlyRotate(openTime time.Time) { tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100)) <-tm.C w.Lock() - if w.needRotateHourly(0, time.Now().Hour()) { + if w.needRotateHourly(time.Now().Hour()) { if err := w.doRotate(time.Now()); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } diff --git a/pkg/infrastructure/logs/file_test.go b/pkg/infrastructure/logs/file_test.go index 7f2a3590..494d0a9e 100644 --- a/pkg/infrastructure/logs/file_test.go +++ b/pkg/infrastructure/logs/file_test.go @@ -268,6 +268,7 @@ func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { Perm: "0660", RotatePerm: "0440", } + fw.formatter = fw if daily { fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) @@ -308,6 +309,8 @@ func testFileDailyRotate(t *testing.T, fn1, fn2 string) { Perm: "0660", RotatePerm: "0440", } + fw.formatter = fw + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) fw.dailyOpenDate = fw.dailyOpenTime.Day() @@ -340,6 +343,8 @@ func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { Perm: "0660", RotatePerm: "0440", } + + fw.formatter = fw fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) fw.hourlyOpenDate = fw.hourlyOpenTime.Hour() diff --git a/pkg/infrastructure/logs/formatter.go b/pkg/infrastructure/logs/formatter.go new file mode 100644 index 00000000..b2599f2d --- /dev/null +++ b/pkg/infrastructure/logs/formatter.go @@ -0,0 +1,34 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +var formatterMap = make(map[string]LogFormatter, 4) + +type LogFormatter interface { + Format(lm *LogMsg) string +} + +// RegisterFormatter register an formatter. Usually you should use this to extend your custom formatter +// for example: +// RegisterFormatter("my-fmt", &MyFormatter{}) +// logs.SetFormatter(Console, `{"formatter": "my-fmt"}`) +func RegisterFormatter(name string, fmtr LogFormatter) { + formatterMap[name] = fmtr +} + +func GetFormatter(name string) (LogFormatter, bool) { + res, ok := formatterMap[name] + return res, ok +} diff --git a/pkg/infrastructure/logs/jianliao.go b/pkg/infrastructure/logs/jianliao.go index 88750125..9757a7d5 100644 --- a/pkg/infrastructure/logs/jianliao.go +++ b/pkg/infrastructure/logs/jianliao.go @@ -6,42 +6,49 @@ import ( "net/http" "net/url" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/pkg/errors" ) // JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook type JLWriter struct { - AuthorName string `json:"authorname"` - Title string `json:"title"` - WebhookURL string `json:"webhookurl"` - RedirectURL string `json:"redirecturl,omitempty"` - ImageURL string `json:"imageurl,omitempty"` - Level int `json:"level"` - customFormatter func(*LogMsg) string + AuthorName string `json:"authorname"` + Title string `json:"title"` + WebhookURL string `json:"webhookurl"` + RedirectURL string `json:"redirecturl,omitempty"` + ImageURL string `json:"imageurl,omitempty"` + Level int `json:"level"` + + formatter LogFormatter + Formatter string `json:"formatter"` } // newJLWriter creates jiaoliao writer. func newJLWriter() Logger { - return &JLWriter{Level: LevelTrace} + res := &JLWriter{Level: LevelTrace} + res.formatter = res + return res } // Init JLWriter with json config string -func (s *JLWriter) Init(jsonConfig string, opts ...utils.KV) error { - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := GetFormatter(elem) - if err != nil { - return err - } - s.customFormatter = formatter - } - } +func (s *JLWriter) Init(config string) error { - return json.Unmarshal([]byte(jsonConfig), s) + res := json.Unmarshal([]byte(config), s) + if res == nil && len(s.Formatter) > 0 { + fmtr, ok := GetFormatter(s.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", s.Formatter)) + } + s.formatter = fmtr + } + return res } func (s *JLWriter) Format(lm *LogMsg) string { - return lm.Msg + return lm.OldStyleFormat() +} + +func (s *JLWriter) SetFormatter(f LogFormatter) { + s.formatter = f } // WriteMsg writes message in smtp writer. @@ -51,14 +58,7 @@ func (s *JLWriter) WriteMsg(lm *LogMsg) error { return nil } - text := "" - - if s.customFormatter != nil { - text = fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.customFormatter(lm)) - } else { - text = fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.Format(lm)) - - } + text := s.formatter.Format(lm) form := url.Values{} form.Add("authorName", s.AuthorName) diff --git a/pkg/infrastructure/logs/log.go b/pkg/infrastructure/logs/log.go index 2d400eba..480cecab 100644 --- a/pkg/infrastructure/logs/log.go +++ b/pkg/infrastructure/logs/log.go @@ -38,13 +38,12 @@ import ( "log" "os" "path" - "reflect" "runtime" "strings" "sync" "time" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/pkg/errors" ) // RFC5424 log message levels. @@ -87,11 +86,11 @@ type newLoggerFunc func() Logger // Logger defines the behavior of a log provider. type Logger interface { - Init(config string, opts ...utils.KV) error + Init(config string) error WriteMsg(lm *LogMsg) error - Format(lm *LogMsg) string Destroy() Flush() + SetFormatter(f LogFormatter) } var adapters = make(map[string]newLoggerFunc) @@ -118,7 +117,6 @@ type BeeLogger struct { init bool enableFuncCallDepth bool loggerFuncCallDepth int - globalFormatter func(*LogMsg) string enableFullFilePath bool asynchronous bool prefix string @@ -127,6 +125,7 @@ type BeeLogger struct { signalChan chan string wg sync.WaitGroup outputs []*nameLogger + globalFormatter string } const defaultAsyncMsgLen = 1e3 @@ -137,15 +136,15 @@ type nameLogger struct { } type LogMsg struct { - Level int - Msg string - When time.Time - FilePath string - LineNumber int -} - -type LogFormatter interface { - Format(lm *LogMsg) string + Level int + Msg string + When time.Time + FilePath string + LineNumber int + Args []interface{} + Prefix string + enableFullFilePath bool + enableFuncCallDepth bool } var logMsgPool *sync.Pool @@ -188,8 +187,25 @@ func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { return bl } -func Format(lm *LogMsg) string { - return lm.Msg +// OldStyleFormat you should never invoke this +func (lm *LogMsg) OldStyleFormat() string { + msg := lm.Msg + + if len(lm.Args) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, lm.Args...) + } + + msg = lm.Prefix + " " + msg + + if lm.enableFuncCallDepth { + if !lm.enableFullFilePath { + _, lm.FilePath = path.Split(lm.FilePath) + } + msg = fmt.Sprintf("[%s:%d] %s", lm.FilePath, lm.LineNumber, msg) + } + + msg = levelPrefix[lm.Level] + " " + msg + return msg } // SetLogger provides a given logger adapter into BeeLogger with config string. @@ -208,16 +224,18 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { } lg := logAdapter() - var err error // Global formatter overrides the default set formatter - // but not adapter specific formatters set with logs.SetLoggerWithOpts() - if bl.globalFormatter != nil { - err = lg.Init(config, &utils.SimpleKV{Key: "formatter", Value: bl.globalFormatter}) - } else { - err = lg.Init(config) + if len(bl.globalFormatter) > 0 { + fmtr, ok := GetFormatter(bl.globalFormatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", bl.globalFormatter)) + } + lg.SetFormatter(fmtr) } + err := lg.Init(config) + if err != nil { fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) return err @@ -287,46 +305,34 @@ func (bl *BeeLogger) Write(p []byte) (n int, err error) { return 0, err } -func (bl *BeeLogger) writeMsg(lm *LogMsg, v ...interface{}) error { +func (bl *BeeLogger) writeMsg(lm *LogMsg) error { if !bl.init { bl.lock.Lock() bl.setLogger(AdapterConsole) bl.lock.Unlock() } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) - } - - lm.Msg = bl.prefix + " " + lm.Msg - var ( file string line int ok bool ) - if bl.enableFuncCallDepth { - _, file, line, ok = runtime.Caller(bl.loggerFuncCallDepth) - if !ok { - file = "???" - line = 0 - } - - if !bl.enableFullFilePath { - _, file = path.Split(file) - } - lm.FilePath = file - lm.LineNumber = line - lm.Msg = fmt.Sprintf("[%s:%d] %s", lm.FilePath, lm.LineNumber, lm.Msg) + _, file, line, ok = runtime.Caller(bl.loggerFuncCallDepth) + if !ok { + file = "???" + line = 0 } + lm.FilePath = file + lm.LineNumber = line + + lm.enableFullFilePath = bl.enableFullFilePath + lm.enableFuncCallDepth = bl.enableFuncCallDepth // set level info in front of filename info if lm.Level == levelLoggerImpl { // set to emergency to ensure all log will be print out correctly lm.Level = LevelEmergency - } else { - lm.Msg = levelPrefix[lm.Level] + " " + lm.Msg } if bl.asynchronous { @@ -334,6 +340,10 @@ func (bl *BeeLogger) writeMsg(lm *LogMsg, v ...interface{}) error { logM.Level = lm.Level logM.Msg = lm.Msg logM.When = lm.When + logM.Args = lm.Args + logM.FilePath = lm.FilePath + logM.LineNumber = lm.LineNumber + logM.Prefix = lm.Prefix if bl.outputs != nil { bl.msgChan <- lm } else { @@ -404,84 +414,14 @@ func (bl *BeeLogger) startLogger() { } } -// Get the formatter from the opts common.SimpleKV structure -// Looks for a key: "formatter" with value: func(*LogMsg) string -func GetFormatter(opts utils.KV) (func(*LogMsg) string, error) { - if strings.ToLower(opts.GetKey().(string)) == "formatter" { - formatterInterface := reflect.ValueOf(opts.GetValue()).Interface() - formatterFunc := formatterInterface.(func(*LogMsg) string) - return formatterFunc, nil - } - - return nil, fmt.Errorf("no \"formatter\" key given in simpleKV") -} - -// SetLoggerWithOpts sets a log adapter with a user defined logging format. Config must be valid JSON -// such as: {"interval":360} -func (bl *BeeLogger) setLoggerWithOpts(adapterName string, opts utils.KV, configs ...string) error { - config := append(configs, "{}")[0] - for _, l := range bl.outputs { - if l.name == adapterName { - return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName) - } - } - - logAdapter, ok := adapters[adapterName] - if !ok { - return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) - } - - if opts.GetKey() == nil { - return fmt.Errorf("No SimpleKV struct set for %s log adapter", adapterName) - } - - lg := logAdapter() - err := lg.Init(config, opts) - if err != nil { - fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) - return err - } - - bl.outputs = append(bl.outputs, &nameLogger{ - name: adapterName, - Logger: lg, - }) - - return nil -} - -// SetLogger provides a given logger adapter into BeeLogger with config string. -func (bl *BeeLogger) SetLoggerWithOpts(adapterName string, opts utils.KV, configs ...string) error { - bl.lock.Lock() - defer bl.lock.Unlock() - if !bl.init { - bl.outputs = []*nameLogger{} - bl.init = true - } - return bl.setLoggerWithOpts(adapterName, opts, configs...) -} - -// SetLoggerWIthOpts sets a given log adapter with a custom log adapter. -// Log Adapter must be given in the form common.SimpleKV{Key: "formatter": Value: struct.FormatFunc} -// where FormatFunc has the signature func(*LogMsg) string -// func SetLoggerWithOpts(adapter string, config []string, formatterFunc func(*LogMsg) string) error { -func SetLoggerWithOpts(adapter string, config []string, opts utils.KV) error { - err := beeLogger.SetLoggerWithOpts(adapter, opts, config...) - if err != nil { - log.Fatal(err) - } - return nil - -} - -func (bl *BeeLogger) setGlobalFormatter(fmtter func(*LogMsg) string) error { +func (bl *BeeLogger) setGlobalFormatter(fmtter string) error { bl.globalFormatter = fmtter return nil } // SetGlobalFormatter sets the global formatter for all log adapters -// This overrides and other individually set adapter -func SetGlobalFormatter(fmtter func(*LogMsg) string) error { +// don't forget to register the formatter by invoking RegisterFormatter +func SetGlobalFormatter(fmtter string) error { return beeLogger.setGlobalFormatter(fmtter) } @@ -513,11 +453,8 @@ func (bl *BeeLogger) Alert(format string, v ...interface{}) { Level: LevelAlert, Msg: format, When: time.Now(), + Args: v, } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) - } - bl.writeMsg(lm) } @@ -530,9 +467,7 @@ func (bl *BeeLogger) Critical(format string, v ...interface{}) { Level: LevelCritical, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -547,9 +482,7 @@ func (bl *BeeLogger) Error(format string, v ...interface{}) { Level: LevelError, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -564,9 +497,7 @@ func (bl *BeeLogger) Warning(format string, v ...interface{}) { Level: LevelWarn, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -581,9 +512,7 @@ func (bl *BeeLogger) Notice(format string, v ...interface{}) { Level: LevelNotice, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -598,9 +527,7 @@ func (bl *BeeLogger) Informational(format string, v ...interface{}) { Level: LevelInfo, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -615,9 +542,7 @@ func (bl *BeeLogger) Debug(format string, v ...interface{}) { Level: LevelDebug, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -633,9 +558,7 @@ func (bl *BeeLogger) Warn(format string, v ...interface{}) { Level: LevelWarn, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -651,9 +574,7 @@ func (bl *BeeLogger) Info(format string, v ...interface{}) { Level: LevelInfo, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -669,9 +590,7 @@ func (bl *BeeLogger) Trace(format string, v ...interface{}) { Level: LevelDebug, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) diff --git a/pkg/infrastructure/logs/log_formatter_test.go b/pkg/infrastructure/logs/log_formatter_test.go deleted file mode 100644 index 73281cf6..00000000 --- a/pkg/infrastructure/logs/log_formatter_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package logs - -import ( - "fmt" - "testing" - - "github.com/astaxie/beego/pkg/infrastructure/utils" -) - -func customFormatter(lm *LogMsg) string { - return fmt.Sprintf("[CUSTOM CONSOLE LOGGING] %s", lm.Msg) -} - -func globalFormatter(lm *LogMsg) string { - return fmt.Sprintf("[GLOBAL] %s", lm.Msg) -} - -func TestCustomLoggingFormatter(t *testing.T) { - // beego.BConfig.Log.AccessLogs = true - - SetLoggerWithOpts("console", []string{`{"color":true}`}, &utils.SimpleKV{Key: "formatter", Value: customFormatter}) - - // Message will be formatted by the customFormatter with colorful text set to true - Informational("Test message") -} - -func TestGlobalLoggingFormatter(t *testing.T) { - SetGlobalFormatter(globalFormatter) - - SetLogger("console", `{"color":true}`) - - // Message will be formatted by globalFormatter - Informational("Test message") - -} diff --git a/pkg/infrastructure/logs/multifile.go b/pkg/infrastructure/logs/multifile.go index bf589b91..79178211 100644 --- a/pkg/infrastructure/logs/multifile.go +++ b/pkg/infrastructure/logs/multifile.go @@ -16,8 +16,6 @@ package logs import ( "encoding/json" - - "github.com/astaxie/beego/pkg/infrastructure/utils" ) // A filesLogWriter manages several fileLogWriter @@ -26,10 +24,9 @@ import ( // and write the error-level logs to project.error.log and write the debug-level logs to project.debug.log // the rotate attribute also acts like fileLogWriter type multiFileLogWriter struct { - writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter - fullLogWriter *fileLogWriter - Separate []string `json:"separate"` - customFormatter func(*LogMsg) string + writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter + fullLogWriter *fileLogWriter + Separate []string `json:"separate"` } var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"} @@ -47,30 +44,27 @@ var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning // "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"], // } -func (f *multiFileLogWriter) Init(jsonConfig string, opts ...utils.KV) error { - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := GetFormatter(elem) - if err != nil { - return err - } - f.customFormatter = formatter - } - } +func (f *multiFileLogWriter) Init(config string) error { writer := newFileWriter().(*fileLogWriter) - err := writer.Init(jsonConfig) + err := writer.Init(config) if err != nil { return err } f.fullLogWriter = writer f.writers[LevelDebug+1] = writer - //unmarshal "separate" field to f.Separate - json.Unmarshal([]byte(jsonConfig), f) + // unmarshal "separate" field to f.Separate + err = json.Unmarshal([]byte(config), f) + if err != nil { + return err + } jsonMap := map[string]interface{}{} - json.Unmarshal([]byte(jsonConfig), &jsonMap) + err = json.Unmarshal([]byte(config), &jsonMap) + if err != nil { + return err + } for i := LevelEmergency; i < LevelDebug+1; i++ { for _, v := range f.Separate { @@ -91,7 +85,11 @@ func (f *multiFileLogWriter) Init(jsonConfig string, opts ...utils.KV) error { } func (f *multiFileLogWriter) Format(lm *LogMsg) string { - return lm.Msg + return lm.OldStyleFormat() +} + +func (f *multiFileLogWriter) SetFormatter(fmt LogFormatter) { + f.fullLogWriter.SetFormatter(f) } func (f *multiFileLogWriter) Destroy() { @@ -126,7 +124,8 @@ func (f *multiFileLogWriter) Flush() { // newFilesWriter create a FileLogWriter returning as LoggerInterface. func newFilesWriter() Logger { - return &multiFileLogWriter{} + res := &multiFileLogWriter{} + return res } func init() { diff --git a/pkg/infrastructure/logs/slack.go b/pkg/infrastructure/logs/slack.go index d56b9acd..b6e2f170 100644 --- a/pkg/infrastructure/logs/slack.go +++ b/pkg/infrastructure/logs/slack.go @@ -6,35 +6,46 @@ import ( "net/http" "net/url" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/pkg/errors" ) // SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook type SLACKWriter struct { - WebhookURL string `json:"webhookurl"` - Level int `json:"level"` - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string + WebhookURL string `json:"webhookurl"` + Level int `json:"level"` + formatter LogFormatter + Formatter string `json:"formatter"` } // newSLACKWriter creates jiaoliao writer. func newSLACKWriter() Logger { - return &SLACKWriter{Level: LevelTrace} + res := &SLACKWriter{Level: LevelTrace} + res.formatter = res + return res } func (s *SLACKWriter) Format(lm *LogMsg) string { - return lm.Msg + text := fmt.Sprintf("{\"text\": \"%s %s\"}", lm.When.Format("2006-01-02 15:04:05"), lm.OldStyleFormat()) + return text +} + +func (s *SLACKWriter) SetFormatter(f LogFormatter) { + s.formatter = f } // Init SLACKWriter with json config string -func (s *SLACKWriter) Init(jsonConfig string, opts ...utils.KV) error { - // if elem != nil { - // s.UseCustomFormatter = true - // s.CustomFormatter = elem - // } - // } +func (s *SLACKWriter) Init(config string) error { + res := json.Unmarshal([]byte(config), s) - return json.Unmarshal([]byte(jsonConfig), s) + if res == nil && len(s.Formatter) > 0 { + fmtr, ok := GetFormatter(s.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", s.Formatter)) + } + s.formatter = fmtr + } + + return res } // WriteMsg write message in smtp writer. @@ -44,10 +55,8 @@ func (s *SLACKWriter) WriteMsg(lm *LogMsg) error { return nil } msg := s.Format(lm) - text := fmt.Sprintf("{\"text\": \"%s %s\"}", lm.When.Format("2006-01-02 15:04:05"), msg) - form := url.Values{} - form.Add("payload", text) + form.Add("payload", msg) resp, err := http.PostForm(s.WebhookURL, form) if err != nil { diff --git a/pkg/infrastructure/logs/smtp.go b/pkg/infrastructure/logs/smtp.go index 904a89df..40891a7c 100644 --- a/pkg/infrastructure/logs/smtp.go +++ b/pkg/infrastructure/logs/smtp.go @@ -22,7 +22,7 @@ import ( "net/smtp" "strings" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/pkg/errors" ) // SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. @@ -34,12 +34,15 @@ type SMTPWriter struct { FromAddress string `json:"fromAddress"` RecipientAddresses []string `json:"sendTos"` Level int `json:"level"` - customFormatter func(*LogMsg) string + formatter LogFormatter + Formatter string `json:"formatter"` } // NewSMTPWriter creates the smtp writer. func newSMTPWriter() Logger { - return &SMTPWriter{Level: LevelTrace} + res := &SMTPWriter{Level: LevelTrace} + res.formatter = res + return res } // Init smtp writer with json config. @@ -53,19 +56,16 @@ func newSMTPWriter() Logger { // "sendTos":["email1","email2"], // "level":LevelError // } -func (s *SMTPWriter) Init(jsonConfig string, opts ...utils.KV) error { - - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := GetFormatter(elem) - if err != nil { - return err - } - s.customFormatter = formatter +func (s *SMTPWriter) Init(config string) error { + res := json.Unmarshal([]byte(config), s) + if res == nil && len(s.Formatter) > 0 { + fmtr, ok := GetFormatter(s.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", s.Formatter)) } + s.formatter = fmtr } - - return json.Unmarshal([]byte(jsonConfig), s) + return res } func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { @@ -80,6 +80,10 @@ func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { ) } +func (s *SMTPWriter) SetFormatter(f LogFormatter) { + s.formatter = f +} + func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error { client, err := smtp.Dial(hostAddressWithPort) if err != nil { @@ -129,7 +133,7 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd } func (s *SMTPWriter) Format(lm *LogMsg) string { - return lm.Msg + return lm.OldStyleFormat() } // WriteMsg writes message in smtp writer. From b575fa1ebe076bcc21d3bb73434aea26f0043191 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Fri, 11 Sep 2020 23:48:21 +0800 Subject: [PATCH 35/35] fix 4219 --- pkg/server/web/context/input.go | 2 +- pkg/server/web/router_test.go | 17 +++++++++++++++++ pkg/server/web/tree.go | 24 ++++++++++++------------ 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/pkg/server/web/context/input.go b/pkg/server/web/context/input.go index a6fec774..f8657f84 100644 --- a/pkg/server/web/context/input.go +++ b/pkg/server/web/context/input.go @@ -89,7 +89,7 @@ func (input *BeegoInput) URI() string { // URL returns the request url path (without query, string and fragment). func (input *BeegoInput) URL() string { - return input.Context.Request.URL.EscapedPath() + return input.Context.Request.URL.Path } // Site returns the base site url as scheme://domain type. diff --git a/pkg/server/web/router_test.go b/pkg/server/web/router_test.go index 33b75703..2863da3a 100644 --- a/pkg/server/web/router_test.go +++ b/pkg/server/web/router_test.go @@ -212,6 +212,23 @@ func TestAutoExtFunc(t *testing.T) { } } +func TestEscape(t *testing.T) { + + r, _ := http.NewRequest("GET", "/search/%E4%BD%A0%E5%A5%BD", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Get("/search/:keyword(.+)", func(ctx *context.Context) { + value := ctx.Input.Param(":keyword") + ctx.Output.Body([]byte(value)) + }) + handler.ServeHTTP(w, r) + str := w.Body.String() + if str != "你好" { + t.Errorf("incorrect, %s", str) + } +} + func TestRouteOk(t *testing.T) { r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil) diff --git a/pkg/server/web/tree.go b/pkg/server/web/tree.go index 7213a0c6..55f68076 100644 --- a/pkg/server/web/tree.go +++ b/pkg/server/web/tree.go @@ -33,13 +33,13 @@ var ( // wildcard stores params // leaves store the endpoint information type Tree struct { - //prefix set for static router + // prefix set for static router prefix string - //search fix route first + // search fix route first fixrouters []*Tree - //if set, failure to match fixrouters search then search wildcard + // if set, failure to match fixrouters search then search wildcard wildcard *Tree - //if set, failure to match wildcard search + // if set, failure to match wildcard search leaves []*leafInfo } @@ -69,13 +69,13 @@ func (t *Tree) addtree(segments []string, tree *Tree, wildcards []string, reg st filterTreeWithPrefix(tree, wildcards, reg) } } - //Rule: /login/*/access match /login/2009/11/access - //if already has *, and when loop the access, should as a regexpStr + // Rule: /login/*/access match /login/2009/11/access + // if already has *, and when loop the access, should as a regexpStr if !iswild && utils.InSlice(":splat", wildcards) { iswild = true regexpStr = seg } - //Rule: /user/:id/* + // Rule: /user/:id/* if seg == "*" && len(wildcards) > 0 && reg == "" { regexpStr = "(.+)" } @@ -222,13 +222,13 @@ func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, t.addseg(segments[1:], route, wildcards, reg) params = params[1:] } - //Rule: /login/*/access match /login/2009/11/access - //if already has *, and when loop the access, should as a regexpStr + // Rule: /login/*/access match /login/2009/11/access + // if already has *, and when loop the access, should as a regexpStr if !iswild && utils.InSlice(":splat", wildcards) { iswild = true regexpStr = seg } - //Rule: /user/:id/* + // Rule: /user/:id/* if seg == "*" && len(wildcards) > 0 && reg == "" { regexpStr = "(.+)" } @@ -393,7 +393,7 @@ type leafInfo struct { } func (leaf *leafInfo) match(treePattern string, wildcardValues []string, ctx *context.Context) (ok bool) { - //fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) + // fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) if leaf.regexps == nil { if len(wildcardValues) == 0 && len(leaf.wildcards) == 0 { // static path return true @@ -500,7 +500,7 @@ func splitSegment(key string) (bool, []string, string) { continue } if start { - //:id:int and :name:string + // :id:int and :name:string if v == ':' { if len(key) >= i+4 { if key[i+1:i+4] == "int" {