diff --git a/pkg/adapter/admin.go b/pkg/adapter/admin.go new file mode 100644 index 00000000..87e7259b --- /dev/null +++ b/pkg/adapter/admin.go @@ -0,0 +1,48 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "time" + + "github.com/astaxie/beego/pkg/server/web" +) + +// FilterMonitorFunc is default monitor filter when admin module is enable. +// if this func returns, admin module records qps for this request by condition of this function logic. +// usage: +// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { +// if method == "POST" { +// return false +// } +// if t.Nanoseconds() < 100 { +// return false +// } +// if strings.HasPrefix(requestPath, "/astaxie") { +// return false +// } +// return true +// } +// beego.FilterMonitorFunc = MyFilterMonitor. +var FilterMonitorFunc func(string, string, time.Duration, string, int) bool + +func init() { + FilterMonitorFunc = web.FilterMonitorFunc +} + +// PrintTree prints all registered routers. +func PrintTree() M { + return (M)(web.PrintTree()) +} diff --git a/pkg/adapter/app.go b/pkg/adapter/app.go new file mode 100644 index 00000000..c1046c79 --- /dev/null +++ b/pkg/adapter/app.go @@ -0,0 +1,262 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + + context2 "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/pkg/server/web/context" +) + +var ( + // BeeApp is an application instance + BeeApp *App +) + +func init() { + // create beego application + BeeApp = (*App)(web.BeeApp) +} + +// App defines beego application with a new PatternServeMux. +type App web.App + +// NewApp returns a new beego application. +func NewApp() *App { + return (*App)(web.NewApp()) +} + +// MiddleWare function for http.Handler +type MiddleWare web.MiddleWare + +// Run beego application. +func (app *App) Run(mws ...MiddleWare) { + newMws := oldMiddlewareToNew(mws) + (*web.App)(app).Run(newMws...) +} + +func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare { + newMws := make([]web.MiddleWare, 0, len(mws)) + for _, old := range mws { + newMws = append(newMws, (web.MiddleWare)(old)) + } + return newMws +} + +// Router adds a patterned controller handler to BeeApp. +// it's an alias method of App.Router. +// usage: +// simple router +// beego.Router("/admin", &admin.UserController{}) +// beego.Router("/admin/index", &admin.ArticleController{}) +// +// regex router +// +// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) +// +// custom rules +// beego.Router("/api/list",&RestController{},"*:ListFood") +// beego.Router("/api/create",&RestController{},"post:CreateFood") +// beego.Router("/api/update",&RestController{},"put:UpdateFood") +// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") +func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { + return (*App)(web.Router(rootpath, c, mappingMethods...)) +} + +// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful +// in web applications that inherit most routes from a base webapp via the underscore +// import, and aim to overwrite only certain paths. +// The method parameter can be empty or "*" for all HTTP methods, or a particular +// method type (e.g. "GET" or "POST") for selective removal. +// +// Usage (replace "GET" with "*" for all methods): +// beego.UnregisterFixedRoute("/yourpreviouspath", "GET") +// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") +func UnregisterFixedRoute(fixedRoute string, method string) *App { + return (*App)(web.UnregisterFixedRoute(fixedRoute, method)) +} + +// Include will generate router file in the router/xxx.go from the controller's comments +// usage: +// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +// type BankAccount struct{ +// beego.Controller +// } +// +// register the function +// func (b *BankAccount)Mapping(){ +// b.Mapping("ShowAccount" , b.ShowAccount) +// b.Mapping("ModifyAccount", b.ModifyAccount) +// } +// +// //@router /account/:id [get] +// func (b *BankAccount) ShowAccount(){ +// //logic +// } +// +// +// //@router /account/:id [post] +// func (b *BankAccount) ModifyAccount(){ +// //logic +// } +// +// the comments @router url methodlist +// url support all the function Router's pattern +// methodlist [get post head put delete options *] +func Include(cList ...ControllerInterface) *App { + newList := oldToNewCtrlIntfs(cList) + return (*App)(web.Include(newList...)) +} + +func oldToNewCtrlIntfs(cList []ControllerInterface) []web.ControllerInterface { + newList := make([]web.ControllerInterface, 0, len(cList)) + for _, c := range cList { + newList = append(newList, c) + } + return newList +} + +// RESTRouter adds a restful controller handler to BeeApp. +// its' controller implements beego.ControllerInterface and +// defines a param "pattern/:objectId" to visit each resource. +func RESTRouter(rootpath string, c ControllerInterface) *App { + return (*App)(web.RESTRouter(rootpath, c)) +} + +// AutoRouter adds defined controller handler to BeeApp. +// it's same to App.AutoRouter. +// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, +// visit the url /main/list to exec List function or /main/page to exec Page function. +func AutoRouter(c ControllerInterface) *App { + return (*App)(web.AutoRouter(c)) +} + +// AutoPrefix adds controller handler to BeeApp with prefix. +// it's same to App.AutoRouterWithPrefix. +// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, +// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. +func AutoPrefix(prefix string, c ControllerInterface) *App { + return (*App)(web.AutoPrefix(prefix, c)) +} + +// Get used to register router for Get method +// usage: +// beego.Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Get(rootpath string, f FilterFunc) *App { + return (*App)(web.Get(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Post used to register router for Post method +// usage: +// beego.Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Post(rootpath string, f FilterFunc) *App { + return (*App)(web.Post(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Delete used to register router for Delete method +// usage: +// beego.Delete("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Delete(rootpath string, f FilterFunc) *App { + return (*App)(web.Delete(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Put used to register router for Put method +// usage: +// beego.Put("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Put(rootpath string, f FilterFunc) *App { + return (*App)(web.Put(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Head used to register router for Head method +// usage: +// beego.Head("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Head(rootpath string, f FilterFunc) *App { + return (*App)(web.Head(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Options used to register router for Options method +// usage: +// beego.Options("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Options(rootpath string, f FilterFunc) *App { + return (*App)(web.Options(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Patch used to register router for Patch method +// usage: +// beego.Patch("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Patch(rootpath string, f FilterFunc) *App { + return (*App)(web.Patch(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Any used to register router for all methods +// usage: +// beego.Any("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Any(rootpath string, f FilterFunc) *App { + return (*App)(web.Any(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Handler used to register a Handler router +// usage: +// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { +// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) +// })) +func Handler(rootpath string, h http.Handler, options ...interface{}) *App { + return (*App)(web.Handler(rootpath, h, options)) +} + +// InsertFilter adds a FilterFunc with pattern condition and action constant. +// The pos means action constant including +// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. +// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) +func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { + opts := oldToNewFilterOpts(params) + return (*App)(web.InsertFilter(pattern, pos, func(ctx *context.Context) { + filter((*context2.Context)(ctx)) + }, opts...)) +} diff --git a/pkg/adapter/beego.go b/pkg/adapter/beego.go new file mode 100644 index 00000000..efd2d4ea --- /dev/null +++ b/pkg/adapter/beego.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/server/web" +) + +const ( + // VERSION represent beego web framework version. + VERSION = web.VERSION + + // DEV is for develop + DEV = web.DEV + // PROD is for production + PROD = web.PROD +) + +// M is Map shortcut +type M web.M + +// Hook function to run +type hookfunc func() error + +var ( + hooks = make([]hookfunc, 0) // hook function slice to store the hookfunc +) + +// AddAPPStartHook is used to register the hookfunc +// The hookfuncs will run in beego.Run() +// such as initiating session , starting middleware , building template, starting admin control and so on. +func AddAPPStartHook(hf ...hookfunc) { + for _, f := range hf { + web.AddAPPStartHook(func() error { + return f() + }) + } +} + +// Run beego application. +// beego.Run() default run on HttpPort +// beego.Run("localhost") +// beego.Run(":8089") +// beego.Run("127.0.0.1:8089") +func Run(params ...string) { + web.Run(params...) +} + +// RunWithMiddleWares Run beego application with middlewares. +func RunWithMiddleWares(addr string, mws ...MiddleWare) { + newMws := oldMiddlewareToNew(mws) + web.RunWithMiddleWares(addr, newMws...) +} + +// TestBeegoInit is for test package init +func TestBeegoInit(ap string) { + web.TestBeegoInit(ap) +} + +// InitBeegoBeforeTest is for test package init +func InitBeegoBeforeTest(appConfigPath string) { + web.InitBeegoBeforeTest(appConfigPath) +} diff --git a/pkg/adapter/build_info.go b/pkg/adapter/build_info.go new file mode 100644 index 00000000..1e8dacf0 --- /dev/null +++ b/pkg/adapter/build_info.go @@ -0,0 +1,27 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +var ( + BuildVersion string + BuildGitRevision string + BuildStatus string + BuildTag string + BuildTime string + + GoVersion string + + GitBranch string +) diff --git a/pkg/adapter/cache/cache.go b/pkg/adapter/cache/cache.go new file mode 100644 index 00000000..21bb9141 --- /dev/null +++ b/pkg/adapter/cache/cache.go @@ -0,0 +1,85 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cache provide a Cache interface and some implement engine +// Usage: +// +// import( +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("memory", `{"interval":60}`) +// +// Use it like this: +// +// bm.Put("astaxie", 1, 10 * time.Second) +// bm.Get("astaxie") +// bm.IsExist("astaxie") +// bm.Delete("astaxie") +// +// more docs http://beego.me/docs/module/cache.md +package cache + +import ( + "fmt" + + "github.com/astaxie/beego/pkg/client/cache" +) + +// Cache interface contains all behaviors for cache adapter. +// usage: +// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go. +// c,err := cache.NewCache("file","{....}") +// c.Put("key",value, 3600 * time.Second) +// v := c.Get("key") +// +// c.Incr("counter") // now is 1 +// c.Incr("counter") // now is 2 +// count := c.Get("counter").(int) +type Cache cache.Cache + +// Instance is a function create a new Cache Instance +type Instance func() Cache + +var adapters = make(map[string]Instance) + +// Register makes a cache adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Instance) { + if adapter == nil { + panic("cache: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("cache: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewCache Create a new cache driver by adapter name and config string. +// config need to be correct JSON as string: {"interval":360}. +// it will start gc automatically. +func NewCache(adapterName, config string) (adapter Cache, err error) { + instanceFunc, ok := adapters[adapterName] + if !ok { + err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + return + } + adapter = instanceFunc() + err = adapter.StartAndGC(config) + if err != nil { + adapter = nil + } + return +} diff --git a/pkg/adapter/config.go b/pkg/adapter/config.go new file mode 100644 index 00000000..1491722c --- /dev/null +++ b/pkg/adapter/config.go @@ -0,0 +1,179 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + context2 "context" + + "github.com/astaxie/beego/pkg/adapter/session" + newCfg "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/server/web" +) + +// Config is the main struct for BConfig +type Config web.Config + +// Listen holds for http and https related config +type Listen web.Listen + +// WebConfig holds web related config +type WebConfig web.WebConfig + +// SessionConfig holds session related config +type SessionConfig web.SessionConfig + +// LogConfig holds Log related config +type LogConfig web.LogConfig + +var ( + // BConfig is the default config for Application + BConfig *Config + // AppConfig is the instance of Config, store the config information from file + AppConfig *beegoAppConfig + // AppPath is the absolute path to the app + AppPath string + // GlobalSessions is the instance for the session manager + GlobalSessions *session.Manager + + // appConfigPath is the path to the config files + appConfigPath string + // appConfigProvider is the provider for the config, default is ini + appConfigProvider = "ini" + // WorkPath is the absolute path to project root directory + WorkPath string +) + +func init() { + BConfig = (*Config)(web.BConfig) + AppPath = web.AppPath + + WorkPath = web.WorkPath + + AppConfig = &beegoAppConfig{innerConfig: (newCfg.Configer)(web.AppConfig)} +} + +// LoadAppConfig allow developer to apply a config file +func LoadAppConfig(adapterName, configPath string) error { + return web.LoadAppConfig(adapterName, configPath) +} + +type beegoAppConfig struct { + innerConfig newCfg.Configer +} + +func (b *beegoAppConfig) Set(key, val string) error { + if err := b.innerConfig.Set(context2.Background(), BConfig.RunMode+"::"+key, val); err != nil { + return b.innerConfig.Set(context2.Background(), key, val) + } + return nil +} + +func (b *beegoAppConfig) String(key string) string { + if v, err := b.innerConfig.String(context2.Background(), BConfig.RunMode+"::"+key); v != "" && err != nil { + return v + } + res, _ := b.innerConfig.String(context2.Background(), key) + return res +} + +func (b *beegoAppConfig) Strings(key string) []string { + if v, err := b.innerConfig.Strings(context2.Background(), BConfig.RunMode+"::"+key); len(v) > 0 && err != nil { + return v + } + res, _ := b.innerConfig.Strings(context2.Background(), key) + return res +} + +func (b *beegoAppConfig) Int(key string) (int, error) { + if v, err := b.innerConfig.Int(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Int(context2.Background(), key) +} + +func (b *beegoAppConfig) Int64(key string) (int64, error) { + if v, err := b.innerConfig.Int64(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Int64(context2.Background(), key) +} + +func (b *beegoAppConfig) Bool(key string) (bool, error) { + if v, err := b.innerConfig.Bool(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Bool(context2.Background(), key) +} + +func (b *beegoAppConfig) Float(key string) (float64, error) { + if v, err := b.innerConfig.Float(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Float(context2.Background(), key) +} + +func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { + if v := b.String(key); v != "" { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { + if v := b.Strings(key); len(v) != 0 { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { + if v, err := b.Int(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { + if v, err := b.Int64(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { + if v, err := b.Bool(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { + if v, err := b.Float(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DIY(key string) (interface{}, error) { + return b.innerConfig.DIY(context2.Background(), key) +} + +func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { + return b.innerConfig.GetSection(context2.Background(), section) +} + +func (b *beegoAppConfig) SaveConfigFile(filename string) error { + return b.innerConfig.SaveConfigFile(context2.Background(), filename) +} diff --git a/pkg/adapter/config/adapter.go b/pkg/adapter/config/adapter.go new file mode 100644 index 00000000..f74b3ff9 --- /dev/null +++ b/pkg/adapter/config/adapter.go @@ -0,0 +1,193 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "context" + + "github.com/pkg/errors" + + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +type newToOldConfigerAdapter struct { + delegate config.Configer +} + +func (c *newToOldConfigerAdapter) Set(key, val string) error { + return c.delegate.Set(context.Background(), key, val) +} + +func (c *newToOldConfigerAdapter) String(key string) string { + res, _ := c.delegate.String(context.Background(), key) + return res +} + +func (c *newToOldConfigerAdapter) Strings(key string) []string { + res, _ := c.delegate.Strings(context.Background(), key) + return res +} + +func (c *newToOldConfigerAdapter) Int(key string) (int, error) { + return c.delegate.Int(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) Int64(key string) (int64, error) { + return c.delegate.Int64(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) Bool(key string) (bool, error) { + return c.delegate.Bool(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) Float(key string) (float64, error) { + return c.delegate.Float(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) DefaultString(key string, defaultVal string) string { + return c.delegate.DefaultString(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultStrings(key string, defaultVal []string) []string { + return c.delegate.DefaultStrings(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultInt(key string, defaultVal int) int { + return c.delegate.DefaultInt(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultInt64(key string, defaultVal int64) int64 { + return c.delegate.DefaultInt64(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultBool(key string, defaultVal bool) bool { + return c.delegate.DefaultBool(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultFloat(key string, defaultVal float64) float64 { + return c.delegate.DefaultFloat(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DIY(key string) (interface{}, error) { + return c.delegate.DIY(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) GetSection(section string) (map[string]string, error) { + return c.delegate.GetSection(context.Background(), section) +} + +func (c *newToOldConfigerAdapter) SaveConfigFile(filename string) error { + return c.delegate.SaveConfigFile(context.Background(), filename) +} + +type oldToNewConfigerAdapter struct { + delegate Configer +} + +func (o *oldToNewConfigerAdapter) Set(ctx context.Context, key, val string) error { + return o.delegate.Set(key, val) +} + +func (o *oldToNewConfigerAdapter) String(ctx context.Context, key string) (string, error) { + return o.delegate.String(key), nil +} + +func (o *oldToNewConfigerAdapter) Strings(ctx context.Context, key string) ([]string, error) { + return o.delegate.Strings(key), nil +} + +func (o *oldToNewConfigerAdapter) Int(ctx context.Context, key string) (int, error) { + return o.delegate.Int(key) +} + +func (o *oldToNewConfigerAdapter) Int64(ctx context.Context, key string) (int64, error) { + return o.delegate.Int64(key) +} + +func (o *oldToNewConfigerAdapter) Bool(ctx context.Context, key string) (bool, error) { + return o.delegate.Bool(key) +} + +func (o *oldToNewConfigerAdapter) Float(ctx context.Context, key string) (float64, error) { + return o.delegate.Float(key) +} + +func (o *oldToNewConfigerAdapter) DefaultString(ctx context.Context, key string, defaultVal string) string { + return o.delegate.DefaultString(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + return o.delegate.DefaultStrings(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultInt(ctx context.Context, key string, defaultVal int) int { + return o.delegate.DefaultInt(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + return o.delegate.DefaultInt64(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + return o.delegate.DefaultBool(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + return o.delegate.DefaultFloat(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DIY(ctx context.Context, key string) (interface{}, error) { + return o.delegate.DIY(key) +} + +func (o *oldToNewConfigerAdapter) GetSection(ctx context.Context, section string) (map[string]string, error) { + return o.delegate.GetSection(section) +} + +func (o *oldToNewConfigerAdapter) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { + return errors.New("unsupported operation, please use actual config.Configer") +} + +func (o *oldToNewConfigerAdapter) Sub(ctx context.Context, key string) (config.Configer, error) { + return nil, errors.New("unsupported operation, please use actual config.Configer") +} + +func (o *oldToNewConfigerAdapter) OnChange(ctx context.Context, key string, fn func(value string)) { + // do nothing +} + +func (o *oldToNewConfigerAdapter) SaveConfigFile(ctx context.Context, filename string) error { + return o.delegate.SaveConfigFile(filename) +} + +type oldToNewConfigAdapter struct { + delegate Config +} + +func (o *oldToNewConfigAdapter) Parse(key string) (config.Configer, error) { + old, err := o.delegate.Parse(key) + if err != nil { + return nil, err + } + return &oldToNewConfigerAdapter{delegate: old}, nil +} + +func (o *oldToNewConfigAdapter) ParseData(data []byte) (config.Configer, error) { + old, err := o.delegate.ParseData(data) + if err != nil { + return nil, err + } + return &oldToNewConfigerAdapter{delegate: old}, nil +} diff --git a/pkg/adapter/config/config.go b/pkg/adapter/config/config.go new file mode 100644 index 00000000..c870a15a --- /dev/null +++ b/pkg/adapter/config/config.go @@ -0,0 +1,151 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package config is used to parse config. +// Usage: +// import "github.com/astaxie/beego/config" +// Examples. +// +// cnf, err := config.NewConfig("ini", "config.conf") +// +// cnf APIS: +// +// cnf.Set(key, val string) error +// cnf.String(key string) string +// cnf.Strings(key string) []string +// cnf.Int(key string) (int, error) +// cnf.Int64(key string) (int64, error) +// cnf.Bool(key string) (bool, error) +// cnf.Float(key string) (float64, error) +// cnf.DefaultString(key string, defaultVal string) string +// cnf.DefaultStrings(key string, defaultVal []string) []string +// cnf.DefaultInt(key string, defaultVal int) int +// cnf.DefaultInt64(key string, defaultVal int64) int64 +// cnf.DefaultBool(key string, defaultVal bool) bool +// cnf.DefaultFloat(key string, defaultVal float64) float64 +// cnf.DIY(key string) (interface{}, error) +// cnf.GetSection(section string) (map[string]string, error) +// cnf.SaveConfigFile(filename string) error +// More docs http://beego.me/docs/module/config.md +package config + +import ( + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +// Configer defines how to get and set value from configuration raw data. +type Configer interface { + Set(key, val string) error // support section::key type in given key when using ini type. + String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + Strings(key string) []string // get string slice + Int(key string) (int, error) + Int64(key string) (int64, error) + Bool(key string) (bool, error) + Float(key string) (float64, error) + DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + DefaultStrings(key string, defaultVal []string) []string // get string slice + DefaultInt(key string, defaultVal int) int + DefaultInt64(key string, defaultVal int64) int64 + DefaultBool(key string, defaultVal bool) bool + DefaultFloat(key string, defaultVal float64) float64 + DIY(key string) (interface{}, error) + GetSection(section string) (map[string]string, error) + SaveConfigFile(filename string) error +} + +// Config is the adapter interface for parsing config file to get raw data to Configer. +type Config interface { + Parse(key string) (Configer, error) + ParseData(data []byte) (Configer, error) +} + +var adapters = make(map[string]Config) + +// Register makes a config adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Config) { + config.Register(name, &oldToNewConfigAdapter{delegate: adapter}) +} + +// NewConfig adapterName is ini/json/xml/yaml. +// filename is the config file path. +func NewConfig(adapterName, filename string) (Configer, error) { + cfg, err := config.NewConfig(adapterName, filename) + if err != nil { + return nil, err + } + + // it was registered by using Register method + res, ok := cfg.(*oldToNewConfigerAdapter) + if ok { + return res.delegate, nil + } + + return &newToOldConfigerAdapter{ + delegate: cfg, + }, nil +} + +// NewConfigData adapterName is ini/json/xml/yaml. +// data is the config data. +func NewConfigData(adapterName string, data []byte) (Configer, error) { + cfg, err := config.NewConfigData(adapterName, data) + if err != nil { + return nil, err + } + + // it was registered by using Register method + res, ok := cfg.(*oldToNewConfigerAdapter) + if ok { + return res.delegate, nil + } + + return &newToOldConfigerAdapter{ + delegate: cfg, + }, nil +} + +// ExpandValueEnvForMap convert all string value with environment variable. +func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { + return config.ExpandValueEnvForMap(m) +} + +// ExpandValueEnv returns value of convert with environment variable. +// +// Return environment variable if value start with "${" and end with "}". +// Return default value if environment variable is empty or not exist. +// +// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue". +// Examples: +// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable. +// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/". +// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie". +func ExpandValueEnv(value string) string { + return config.ExpandValueEnv(value) +} + +// ParseBool returns the boolean value represented by the string. +// +// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On, +// 0, 0.0, f, F, FALSE, false, False, NO, no, No, N,n, OFF, off, Off. +// Any other value returns an error. +func ParseBool(val interface{}) (value bool, err error) { + return config.ParseBool(val) +} + +// ToString converts values of any type to string. +func ToString(x interface{}) string { + return config.ToString(x) +} diff --git a/pkg/adapter/config/config_test.go b/pkg/adapter/config/config_test.go new file mode 100644 index 00000000..15d6ffa6 --- /dev/null +++ b/pkg/adapter/config/config_test.go @@ -0,0 +1,55 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "os" + "testing" +) + +func TestExpandValueEnv(t *testing.T) { + + testCases := []struct { + item string + want string + }{ + {"", ""}, + {"$", "$"}, + {"{", "{"}, + {"{}", "{}"}, + {"${}", ""}, + {"${|}", ""}, + {"${}", ""}, + {"${{}}", ""}, + {"${{||}}", "}"}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}}", "}"}, + {"${pwd||{{||}}}", "{{||}}"}, + {"${GOPATH}", os.Getenv("GOPATH")}, + {"${GOPATH||}", os.Getenv("GOPATH")}, + {"${GOPATH||root}", os.Getenv("GOPATH")}, + {"${GOPATH_NOT||root}", "root"}, + {"${GOPATH_NOT||||root}", "||root"}, + } + + for _, c := range testCases { + if got := ExpandValueEnv(c.item); got != c.want { + t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) + } + } + +} diff --git a/pkg/adapter/config/env/env.go b/pkg/adapter/config/env/env.go new file mode 100644 index 00000000..77d7b53c --- /dev/null +++ b/pkg/adapter/config/env/env.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package env is used to parse environment. +package env + +import ( + "github.com/astaxie/beego/pkg/infrastructure/config/env" +) + +// Get returns a value by key. +// If the key does not exist, the default value will be returned. +func Get(key string, defVal string) string { + return env.Get(key, defVal) +} + +// MustGet returns a value by key. +// If the key does not exist, it will return an error. +func MustGet(key string) (string, error) { + return env.MustGet(key) +} + +// Set sets a value in the ENV copy. +// This does not affect the child process environment. +func Set(key string, value string) { + env.Set(key, value) +} + +// MustSet sets a value in the ENV copy and the child process environment. +// It returns an error in case the set operation failed. +func MustSet(key string, value string) error { + return env.MustSet(key, value) +} + +// GetAll returns all keys/values in the current child process environment. +func GetAll() map[string]string { + return env.GetAll() +} diff --git a/pkg/adapter/config/env/env_test.go b/pkg/adapter/config/env/env_test.go new file mode 100644 index 00000000..3f1d4dba --- /dev/null +++ b/pkg/adapter/config/env/env_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package env + +import ( + "os" + "testing" +) + +func TestEnvGet(t *testing.T) { + gopath := Get("GOPATH", "") + if gopath != os.Getenv("GOPATH") { + t.Error("expected GOPATH not empty.") + } + + noExistVar := Get("NOEXISTVAR", "foo") + if noExistVar != "foo" { + t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar) + } +} + +func TestEnvMustGet(t *testing.T) { + gopath, err := MustGet("GOPATH") + if err != nil { + t.Error(err) + } + + if gopath != os.Getenv("GOPATH") { + t.Errorf("expected GOPATH to be the same, got %s.", gopath) + } + + _, err = MustGet("NOEXISTVAR") + if err == nil { + t.Error("expected error to be non-nil") + } +} + +func TestEnvSet(t *testing.T) { + Set("MYVAR", "foo") + myVar := Get("MYVAR", "bar") + if myVar != "foo" { + t.Errorf("expected MYVAR to equal foo, got %s.", myVar) + } +} + +func TestEnvMustSet(t *testing.T) { + err := MustSet("FOO", "bar") + if err != nil { + t.Error(err) + } + + fooVar := os.Getenv("FOO") + if fooVar != "bar" { + t.Errorf("expected FOO variable to equal bar, got %s.", fooVar) + } +} + +func TestEnvGetAll(t *testing.T) { + envMap := GetAll() + if len(envMap) == 0 { + t.Error("expected environment not empty.") + } +} diff --git a/pkg/adapter/config/fake.go b/pkg/adapter/config/fake.go new file mode 100644 index 00000000..fac96b41 --- /dev/null +++ b/pkg/adapter/config/fake.go @@ -0,0 +1,25 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +// NewFakeConfig return a fake Configer +func NewFakeConfig() Configer { + new := config.NewFakeConfig() + return &newToOldConfigerAdapter{delegate: new} +} diff --git a/pkg/adapter/config/ini_test.go b/pkg/adapter/config/ini_test.go new file mode 100644 index 00000000..ffcdb294 --- /dev/null +++ b/pkg/adapter/config/ini_test.go @@ -0,0 +1,190 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +func TestIni(t *testing.T) { + + var ( + inicontext = ` +;comment one +#comment two +appname = beeapi +httpport = 8080 +mysqlport = 3600 +PI = 3.1415976 +runmode = "dev" +autorender = false +copyrequestbody = true +session= on +cookieon= off +newreg = OFF +needlogin = ON +enableSession = Y +enableCookie = N +flag = 1 +path1 = ${GOPATH} +path2 = ${GOPATH||/home/go} +[demo] +key1="asta" +key2 = "xie" +CaseInsensitive = true +peers = one;two;three +password = ${GOPATH} +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "pi": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "demo::key1": "asta", + "demo::key2": "xie", + "demo::CaseInsensitive": true, + "demo::peers": []string{"one", "two", "three"}, + "demo::password": os.Getenv("GOPATH"), + "null": "", + "demo2::key1": "", + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testini.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(inicontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testini.conf") + iniconf, err := NewConfig("ini", "testini.conf") + if err != nil { + t.Fatal(err) + } + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = iniconf.Int(k) + case int64: + value, err = iniconf.Int64(k) + case float64: + value, err = iniconf.Float(k) + case bool: + value, err = iniconf.Bool(k) + case []string: + value = iniconf.Strings(k) + case string: + value = iniconf.String(k) + default: + value, err = iniconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fail,err %s", k, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = iniconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if iniconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} + +func TestIniSave(t *testing.T) { + + const ( + inicontext = ` +app = app +;comment one +#comment two +# comment three +appname = beeapi +httpport = 8080 +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name = mysql +` + + saveResult = ` +app=app +#comment one +#comment two +# comment three +appname=beeapi +httpport=8080 + +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name=mysql +` + ) + cfg, err := NewConfigData("ini", []byte(inicontext)) + if err != nil { + t.Fatal(err) + } + name := "newIniConfig.ini" + if err := cfg.SaveConfigFile(name); err != nil { + t.Fatal(err) + } + defer os.Remove(name) + + if data, err := ioutil.ReadFile(name); err != nil { + t.Fatal(err) + } else { + cfgData := string(data) + datas := strings.Split(saveResult, "\n") + for _, line := range datas { + if !strings.Contains(cfgData, line+"\n") { + t.Fatalf("different after save ini config file. need contains %q", line) + } + } + + } +} diff --git a/pkg/adapter/config/json.go b/pkg/adapter/config/json.go new file mode 100644 index 00000000..d0fe4d09 --- /dev/null +++ b/pkg/adapter/config/json.go @@ -0,0 +1,19 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + _ "github.com/astaxie/beego/pkg/infrastructure/config/json" +) diff --git a/pkg/adapter/config/json_test.go b/pkg/adapter/config/json_test.go new file mode 100644 index 00000000..16f42409 --- /dev/null +++ b/pkg/adapter/config/json_test.go @@ -0,0 +1,222 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "os" + "testing" +) + +func TestJsonStartsWithArray(t *testing.T) { + + const jsoncontextwitharray = `[ + { + "url": "user", + "serviceAPI": "http://www.test.com/user" + }, + { + "url": "employee", + "serviceAPI": "http://www.test.com/employee" + } +]` + f, err := os.Create("testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontextwitharray) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjsonWithArray.conf") + jsonconf, err := NewConfig("json", "testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + rootArray, err := jsonconf.DIY("rootArray") + if err != nil { + t.Error("array does not exist as element") + } + rootArrayCasted := rootArray.([]interface{}) + if rootArrayCasted == nil { + t.Error("array from root is nil") + } else { + elem := rootArrayCasted[0].(map[string]interface{}) + if elem["url"] != "user" || elem["serviceAPI"] != "http://www.test.com/user" { + t.Error("array[0] values are not valid") + } + + elem2 := rootArrayCasted[1].(map[string]interface{}) + if elem2["url"] != "employee" || elem2["serviceAPI"] != "http://www.test.com/employee" { + t.Error("array[1] values are not valid") + } + } +} + +func TestJson(t *testing.T) { + + var ( + jsoncontext = `{ +"appname": "beeapi", +"testnames": "foo;bar", +"httpport": 8080, +"mysqlport": 3600, +"PI": 3.1415976, +"runmode": "dev", +"autorender": false, +"copyrequestbody": true, +"session": "on", +"cookieon": "off", +"newreg": "OFF", +"needlogin": "ON", +"enableSession": "Y", +"enableCookie": "N", +"flag": 1, +"path1": "${GOPATH}", +"path2": "${GOPATH||/home/go}", +"database": { + "host": "host", + "port": "port", + "database": "database", + "username": "username", + "password": "${GOPATH}", + "conns":{ + "maxconnection":12, + "autoconnect":true, + "connectioninfo":"info", + "root": "${GOPATH}" + } + } +}` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "testnames": []string{"foo", "bar"}, + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "database::host": "host", + "database::port": "port", + "database::database": "database", + "database::password": os.Getenv("GOPATH"), + "database::conns::maxconnection": 12, + "database::conns::autoconnect": true, + "database::conns::connectioninfo": "info", + "database::conns::root": os.Getenv("GOPATH"), + "unknown": "", + } + ) + + f, err := os.Create("testjson.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjson.conf") + jsonconf, err := NewConfig("json", "testjson.conf") + if err != nil { + t.Fatal(err) + } + + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = jsonconf.Int(k) + case int64: + value, err = jsonconf.Int64(k) + case float64: + value, err = jsonconf.Float(k) + case bool: + value, err = jsonconf.Bool(k) + case []string: + value = jsonconf.Strings(k) + case string: + value = jsonconf.String(k) + default: + value, err = jsonconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = jsonconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if jsonconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + + if db, err := jsonconf.DIY("database"); err != nil { + t.Fatal(err) + } else if m, ok := db.(map[string]interface{}); !ok { + t.Log(db) + t.Fatal("db not map[string]interface{}") + } else { + if m["host"].(string) != "host" { + t.Fatal("get host err") + } + } + + if _, err := jsonconf.Int("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int") + } + + if _, err := jsonconf.Int64("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int64") + } + + if _, err := jsonconf.Float("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Float") + } + + if _, err := jsonconf.DIY("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an interface{}") + } + + if val := jsonconf.String("unknown"); val != "" { + t.Error("unknown keys should return an empty string when expecting a String") + } + + if _, err := jsonconf.Bool("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Bool") + } + + if !jsonconf.DefaultBool("unknown", true) { + t.Error("unknown keys with default value wrong") + } +} diff --git a/pkg/adapter/config/xml/xml.go b/pkg/adapter/config/xml/xml.go new file mode 100644 index 00000000..f96cdcd6 --- /dev/null +++ b/pkg/adapter/config/xml/xml.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package xml for config provider. +// +// depend on github.com/beego/x2j. +// +// go install github.com/beego/x2j. +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/xml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("xml", "config.xml") +// +// More docs http://beego.me/docs/module/config.md +package xml + +import ( + _ "github.com/astaxie/beego/pkg/infrastructure/config/xml" +) diff --git a/pkg/adapter/config/xml/xml_test.go b/pkg/adapter/config/xml/xml_test.go new file mode 100644 index 00000000..122c5027 --- /dev/null +++ b/pkg/adapter/config/xml/xml_test.go @@ -0,0 +1,125 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/pkg/adapter/config" +) + +func TestXML(t *testing.T) { + + var ( + //xml parse should incluce in tags + xmlcontext = ` + +beeapi +8080 +3600 +3.1415976 +dev +false +true +${GOPATH} +${GOPATH||/home/go} + +1 +MySection + + +` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testxml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(xmlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testxml.conf") + + xmlconf, err := config.NewConfig("xml", "testxml.conf") + if err != nil { + t.Fatal(err) + } + + var xmlsection map[string]string + xmlsection, err = xmlconf.GetSection("mysection") + if err != nil { + t.Fatal(err) + } + + if len(xmlsection) == 0 { + t.Error("section should not be empty") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = xmlconf.Int(k) + case int64: + value, err = xmlconf.Int64(k) + case float64: + value, err = xmlconf.Float(k) + case bool: + value, err = xmlconf.Bool(k) + case []string: + value = xmlconf.Strings(k) + case string: + value = xmlconf.String(k) + default: + value, err = xmlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = xmlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if xmlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } +} diff --git a/pkg/adapter/config/yaml/yaml.go b/pkg/adapter/config/yaml/yaml.go new file mode 100644 index 00000000..bc2398e9 --- /dev/null +++ b/pkg/adapter/config/yaml/yaml.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package yaml for config provider +// +// depend on github.com/beego/goyaml2 +// +// go install github.com/beego/goyaml2 +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/yaml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("yaml", "config.yaml") +// +// More docs http://beego.me/docs/module/config.md +package yaml + +import ( + _ "github.com/astaxie/beego/pkg/infrastructure/config/yaml" +) diff --git a/pkg/adapter/config/yaml/yaml_test.go b/pkg/adapter/config/yaml/yaml_test.go new file mode 100644 index 00000000..e4e309a2 --- /dev/null +++ b/pkg/adapter/config/yaml/yaml_test.go @@ -0,0 +1,115 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yaml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/pkg/adapter/config" +) + +func TestYaml(t *testing.T) { + + var ( + yamlcontext = ` +"appname": beeapi +"httpport": 8080 +"mysqlport": 3600 +"PI": 3.1415976 +"runmode": dev +"autorender": false +"copyrequestbody": true +"PATH": GOPATH +"path1": ${GOPATH} +"path2": ${GOPATH||/home/go} +"empty": "" +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "PATH": "GOPATH", + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + f, err := os.Create("testyaml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(yamlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testyaml.conf") + yamlconf, err := config.NewConfig("yaml", "testyaml.conf") + if err != nil { + t.Fatal(err) + } + + if yamlconf.String("appname") != "beeapi" { + t.Fatal("appname not equal to beeapi") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = yamlconf.Int(k) + case int64: + value, err = yamlconf.Int64(k) + case float64: + value, err = yamlconf.Float(k) + case bool: + value, err = yamlconf.Bool(k) + case []string: + value = yamlconf.Strings(k) + case string: + value = yamlconf.String(k) + default: + value, err = yamlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = yamlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if yamlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} diff --git a/pkg/adapter/context/acceptencoder.go b/pkg/adapter/context/acceptencoder.go new file mode 100644 index 00000000..e578de45 --- /dev/null +++ b/pkg/adapter/context/acceptencoder.go @@ -0,0 +1,45 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "io" + "net/http" + "os" + + "github.com/astaxie/beego/pkg/server/web/context" +) + +// InitGzip init the gzipcompress +func InitGzip(minLength, compressLevel int, methods []string) { + context.InitGzip(minLength, compressLevel, methods) +} + +// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) +func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { + return context.WriteFile(encoding, writer, file) +} + +// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) +func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { + return context.WriteBody(encoding, writer, content) +} + +// ParseEncoding will extract the right encoding for response +// the Accept-Encoding's sec is here: +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3 +func ParseEncoding(r *http.Request) string { + return context.ParseEncoding(r) +} diff --git a/pkg/adapter/context/context.go b/pkg/adapter/context/context.go new file mode 100644 index 00000000..f9d8c624 --- /dev/null +++ b/pkg/adapter/context/context.go @@ -0,0 +1,146 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package context provide the context utils +// Usage: +// +// import "github.com/astaxie/beego/context" +// +// ctx := context.Context{Request:req,ResponseWriter:rw} +// +// more docs http://beego.me/docs/module/context.md +package context + +import ( + "bufio" + "net" + "net/http" + + "github.com/astaxie/beego/pkg/server/web/context" +) + +// commonly used mime-types +const ( + ApplicationJSON = context.ApplicationJSON + ApplicationXML = context.ApplicationXML + ApplicationYAML = context.ApplicationYAML + TextXML = context.TextXML +) + +// NewContext return the Context with Input and Output +func NewContext() *Context { + return (*Context)(context.NewContext()) +} + +// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. +// BeegoInput and BeegoOutput provides some api to operate request and response more easily. +type Context context.Context + +// Reset init Context, BeegoInput and BeegoOutput +func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { + (*context.Context)(ctx).Reset(rw, r) +} + +// Redirect does redirection to localurl with http header status code. +func (ctx *Context) Redirect(status int, localurl string) { + (*context.Context)(ctx).Redirect(status, localurl) +} + +// Abort stops this request. +// if beego.ErrorMaps exists, panic body. +func (ctx *Context) Abort(status int, body string) { + (*context.Context)(ctx).Abort(status, body) +} + +// WriteString Write string to response body. +// it sends response body. +func (ctx *Context) WriteString(content string) { + (*context.Context)(ctx).WriteString(content) +} + +// GetCookie Get cookie from request by a given key. +// It's alias of BeegoInput.Cookie. +func (ctx *Context) GetCookie(key string) string { + return (*context.Context)(ctx).GetCookie(key) +} + +// SetCookie Set cookie for response. +// It's alias of BeegoOutput.Cookie. +func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { + (*context.Context)(ctx).SetCookie(name, value, others) +} + +// GetSecureCookie Get secure cookie from request by a given key. +func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { + return (*context.Context)(ctx).GetSecureCookie(Secret, key) +} + +// SetSecureCookie Set Secure cookie for response. +func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { + (*context.Context)(ctx).SetSecureCookie(Secret, name, value, others) +} + +// XSRFToken creates a xsrf token string and returns. +func (ctx *Context) XSRFToken(key string, expire int64) string { + return (*context.Context)(ctx).XSRFToken(key, expire) +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (ctx *Context) CheckXSRFCookie() bool { + return (*context.Context)(ctx).CheckXSRFCookie() +} + +// RenderMethodResult renders the return value of a controller method to the output +func (ctx *Context) RenderMethodResult(result interface{}) { + (*context.Context)(ctx).RenderMethodResult(result) +} + +// Response is a wrapper for the http.ResponseWriter +// started set to true if response was written to then don't execute other handler +type Response context.Response + +// Write writes the data to the connection as part of an HTTP reply, +// and sets `started` to true. +// started means the response has sent out. +func (r *Response) Write(p []byte) (int, error) { + return (*context.Response)(r).Write(p) +} + +// WriteHeader sends an HTTP response header with status code, +// and sets `started` to true. +func (r *Response) WriteHeader(code int) { + (*context.Response)(r).WriteHeader(code) +} + +// Hijack hijacker for http +func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return (*context.Response)(r).Hijack() +} + +// Flush http.Flusher +func (r *Response) Flush() { + (*context.Response)(r).Flush() +} + +// CloseNotify http.CloseNotifier +func (r *Response) CloseNotify() <-chan bool { + return (*context.Response)(r).CloseNotify() +} + +// Pusher http.Pusher +func (r *Response) Pusher() (pusher http.Pusher) { + return (*context.Response)(r).Pusher() +} diff --git a/pkg/adapter/context/input.go b/pkg/adapter/context/input.go new file mode 100644 index 00000000..a1d08855 --- /dev/null +++ b/pkg/adapter/context/input.go @@ -0,0 +1,282 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "github.com/astaxie/beego/pkg/server/web/context" +) + +// BeegoInput operates the http request header, data, cookie and body. +// it also contains router params and current session. +type BeegoInput context.BeegoInput + +// NewInput return BeegoInput generated by Context. +func NewInput() *BeegoInput { + return (*BeegoInput)(context.NewInput()) +} + +// Reset init the BeegoInput +func (input *BeegoInput) Reset(ctx *Context) { + (*context.BeegoInput)(input).Reset((*context.Context)(ctx)) +} + +// Protocol returns request protocol name, such as HTTP/1.1 . +func (input *BeegoInput) Protocol() string { + return (*context.BeegoInput)(input).Protocol() +} + +// URI returns full request url with query string, fragment. +func (input *BeegoInput) URI() string { + return input.Context.Request.RequestURI +} + +// URL returns request url path (without query string, fragment). +func (input *BeegoInput) URL() string { + return (*context.BeegoInput)(input).URL() +} + +// Site returns base site url as scheme://domain type. +func (input *BeegoInput) Site() string { + return (*context.BeegoInput)(input).Site() +} + +// Scheme returns request scheme as "http" or "https". +func (input *BeegoInput) Scheme() string { + return (*context.BeegoInput)(input).Scheme() +} + +// Domain returns host name. +// Alias of Host method. +func (input *BeegoInput) Domain() string { + return (*context.BeegoInput)(input).Domain() +} + +// Host returns host name. +// if no host info in request, return localhost. +func (input *BeegoInput) Host() string { + return (*context.BeegoInput)(input).Host() +} + +// Method returns http request method. +func (input *BeegoInput) Method() string { + return (*context.BeegoInput)(input).Method() +} + +// Is returns boolean of this request is on given method, such as Is("POST"). +func (input *BeegoInput) Is(method string) bool { + return (*context.BeegoInput)(input).Is(method) +} + +// IsGet Is this a GET method request? +func (input *BeegoInput) IsGet() bool { + return (*context.BeegoInput)(input).IsGet() +} + +// IsPost Is this a POST method request? +func (input *BeegoInput) IsPost() bool { + return (*context.BeegoInput)(input).IsPost() +} + +// IsHead Is this a Head method request? +func (input *BeegoInput) IsHead() bool { + return (*context.BeegoInput)(input).IsHead() +} + +// IsOptions Is this a OPTIONS method request? +func (input *BeegoInput) IsOptions() bool { + return (*context.BeegoInput)(input).IsOptions() +} + +// IsPut Is this a PUT method request? +func (input *BeegoInput) IsPut() bool { + return (*context.BeegoInput)(input).IsPut() +} + +// IsDelete Is this a DELETE method request? +func (input *BeegoInput) IsDelete() bool { + return (*context.BeegoInput)(input).IsDelete() +} + +// IsPatch Is this a PATCH method request? +func (input *BeegoInput) IsPatch() bool { + return (*context.BeegoInput)(input).IsPatch() +} + +// IsAjax returns boolean of this request is generated by ajax. +func (input *BeegoInput) IsAjax() bool { + return (*context.BeegoInput)(input).IsAjax() +} + +// IsSecure returns boolean of this request is in https. +func (input *BeegoInput) IsSecure() bool { + return (*context.BeegoInput)(input).IsSecure() +} + +// IsWebsocket returns boolean of this request is in webSocket. +func (input *BeegoInput) IsWebsocket() bool { + return (*context.BeegoInput)(input).IsWebsocket() +} + +// IsUpload returns boolean of whether file uploads in this request or not.. +func (input *BeegoInput) IsUpload() bool { + return (*context.BeegoInput)(input).IsUpload() +} + +// AcceptsHTML Checks if request accepts html response +func (input *BeegoInput) AcceptsHTML() bool { + return (*context.BeegoInput)(input).AcceptsHTML() +} + +// AcceptsXML Checks if request accepts xml response +func (input *BeegoInput) AcceptsXML() bool { + return (*context.BeegoInput)(input).AcceptsXML() +} + +// AcceptsJSON Checks if request accepts json response +func (input *BeegoInput) AcceptsJSON() bool { + return (*context.BeegoInput)(input).AcceptsJSON() +} + +// AcceptsYAML Checks if request accepts json response +func (input *BeegoInput) AcceptsYAML() bool { + return (*context.BeegoInput)(input).AcceptsYAML() +} + +// IP returns request client ip. +// if in proxy, return first proxy id. +// if error, return RemoteAddr. +func (input *BeegoInput) IP() string { + return (*context.BeegoInput)(input).IP() +} + +// Proxy returns proxy client ips slice. +func (input *BeegoInput) Proxy() []string { + return (*context.BeegoInput)(input).Proxy() +} + +// Referer returns http referer header. +func (input *BeegoInput) Referer() string { + return (*context.BeegoInput)(input).Referer() +} + +// Refer returns http referer header. +func (input *BeegoInput) Refer() string { + return (*context.BeegoInput)(input).Refer() +} + +// SubDomains returns sub domain string. +// if aa.bb.domain.com, returns aa.bb . +func (input *BeegoInput) SubDomains() string { + return (*context.BeegoInput)(input).SubDomains() +} + +// Port returns request client port. +// when error or empty, return 80. +func (input *BeegoInput) Port() int { + return (*context.BeegoInput)(input).Port() +} + +// UserAgent returns request client user agent string. +func (input *BeegoInput) UserAgent() string { + return (*context.BeegoInput)(input).UserAgent() +} + +// ParamsLen return the length of the params +func (input *BeegoInput) ParamsLen() int { + return (*context.BeegoInput)(input).ParamsLen() +} + +// Param returns router param by a given key. +func (input *BeegoInput) Param(key string) string { + return (*context.BeegoInput)(input).Param(key) +} + +// Params returns the map[key]value. +func (input *BeegoInput) Params() map[string]string { + return (*context.BeegoInput)(input).Params() +} + +// SetParam will set the param with key and value +func (input *BeegoInput) SetParam(key, val string) { + (*context.BeegoInput)(input).SetParam(key, val) +} + +// ResetParams clears any of the input's Params +// This function is used to clear parameters so they may be reset between filter +// passes. +func (input *BeegoInput) ResetParams() { + (*context.BeegoInput)(input).ResetParams() +} + +// Query returns input data item string by a given string. +func (input *BeegoInput) Query(key string) string { + return (*context.BeegoInput)(input).Query(key) +} + +// Header returns request header item string by a given string. +// if non-existed, return empty string. +func (input *BeegoInput) Header(key string) string { + return (*context.BeegoInput)(input).Header(key) +} + +// Cookie returns request cookie item string by a given key. +// if non-existed, return empty string. +func (input *BeegoInput) Cookie(key string) string { + return (*context.BeegoInput)(input).Cookie(key) +} + +// Session returns current session item value by a given key. +// if non-existed, return nil. +func (input *BeegoInput) Session(key interface{}) interface{} { + return (*context.BeegoInput)(input).Session(key) +} + +// CopyBody returns the raw request body data as bytes. +func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { + return (*context.BeegoInput)(input).CopyBody(MaxMemory) +} + +// Data return the implicit data in the input +func (input *BeegoInput) Data() map[interface{}]interface{} { + return (*context.BeegoInput)(input).Data() +} + +// GetData returns the stored data in this context. +func (input *BeegoInput) GetData(key interface{}) interface{} { + return (*context.BeegoInput)(input).GetData(key) +} + +// SetData stores data with given key in this context. +// This data are only available in this context. +func (input *BeegoInput) SetData(key, val interface{}) { + (*context.BeegoInput)(input).SetData(key, val) +} + +// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type +func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { + return (*context.BeegoInput)(input).ParseFormOrMulitForm(maxMemory) +} + +// Bind data from request.Form[key] to dest +// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie +// var id int beegoInput.Bind(&id, "id") id ==123 +// var isok bool beegoInput.Bind(&isok, "isok") isok ==true +// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2 +// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2] +// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array] +// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"} +func (input *BeegoInput) Bind(dest interface{}, key string) error { + return (*context.BeegoInput)(input).Bind(dest, key) +} diff --git a/pkg/adapter/context/output.go b/pkg/adapter/context/output.go new file mode 100644 index 00000000..8e2a7f7d --- /dev/null +++ b/pkg/adapter/context/output.go @@ -0,0 +1,154 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "github.com/astaxie/beego/pkg/server/web/context" +) + +// BeegoOutput does work for sending response header. +type BeegoOutput context.BeegoOutput + +// NewOutput returns new BeegoOutput. +// it contains nothing now. +func NewOutput() *BeegoOutput { + return (*BeegoOutput)(context.NewOutput()) +} + +// Reset init BeegoOutput +func (output *BeegoOutput) Reset(ctx *Context) { + (*context.BeegoOutput)(output).Reset((*context.Context)(ctx)) +} + +// Header sets response header item string via given key. +func (output *BeegoOutput) Header(key, val string) { + (*context.BeegoOutput)(output).Header(key, val) +} + +// Body sets response body content. +// if EnableGzip, compress content string. +// it sends out response body directly. +func (output *BeegoOutput) Body(content []byte) error { + return (*context.BeegoOutput)(output).Body(content) +} + +// Cookie sets cookie value via given key. +// others are ordered as cookie's max age time, path,domain, secure and httponly. +func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) { + (*context.BeegoOutput)(output).Cookie(name, value, others) +} + +// JSON writes json to response body. +// if encoding is true, it converts utf-8 to \u0000 type. +func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error { + return (*context.BeegoOutput)(output).JSON(data, hasIndent, encoding) +} + +// YAML writes yaml to response body. +func (output *BeegoOutput) YAML(data interface{}) error { + return (*context.BeegoOutput)(output).YAML(data) +} + +// JSONP writes jsonp to response body. +func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { + return (*context.BeegoOutput)(output).JSONP(data, hasIndent) +} + +// XML writes xml string to response body. +func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { + return (*context.BeegoOutput)(output).XML(data, hasIndent) +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) { + (*context.BeegoOutput)(output).ServeFormatted(data, hasIndent, hasEncode...) +} + +// Download forces response for download file. +// it prepares the download response header automatically. +func (output *BeegoOutput) Download(file string, filename ...string) { + (*context.BeegoOutput)(output).Download(file, filename...) +} + +// ContentType sets the content type from ext string. +// MIME type is given in mime package. +func (output *BeegoOutput) ContentType(ext string) { + (*context.BeegoOutput)(output).ContentType(ext) +} + +// SetStatus sets response status code. +// It writes response header directly. +func (output *BeegoOutput) SetStatus(status int) { + (*context.BeegoOutput)(output).SetStatus(status) +} + +// IsCachable returns boolean of this request is cached. +// HTTP 304 means cached. +func (output *BeegoOutput) IsCachable() bool { + return (*context.BeegoOutput)(output).IsCachable() +} + +// IsEmpty returns boolean of this request is empty. +// HTTP 201,204 and 304 means empty. +func (output *BeegoOutput) IsEmpty() bool { + return (*context.BeegoOutput)(output).IsEmpty() +} + +// IsOk returns boolean of this request runs well. +// HTTP 200 means ok. +func (output *BeegoOutput) IsOk() bool { + return (*context.BeegoOutput)(output).IsOk() +} + +// IsSuccessful returns boolean of this request runs successfully. +// HTTP 2xx means ok. +func (output *BeegoOutput) IsSuccessful() bool { + return (*context.BeegoOutput)(output).IsSuccessful() +} + +// IsRedirect returns boolean of this request is redirection header. +// HTTP 301,302,307 means redirection. +func (output *BeegoOutput) IsRedirect() bool { + return (*context.BeegoOutput)(output).IsRedirect() +} + +// IsForbidden returns boolean of this request is forbidden. +// HTTP 403 means forbidden. +func (output *BeegoOutput) IsForbidden() bool { + return (*context.BeegoOutput)(output).IsForbidden() +} + +// IsNotFound returns boolean of this request is not found. +// HTTP 404 means not found. +func (output *BeegoOutput) IsNotFound() bool { + return (*context.BeegoOutput)(output).IsNotFound() +} + +// IsClientError returns boolean of this request client sends error data. +// HTTP 4xx means client error. +func (output *BeegoOutput) IsClientError() bool { + return (*context.BeegoOutput)(output).IsClientError() +} + +// IsServerError returns boolean of this server handler errors. +// HTTP 5xx means server internal error. +func (output *BeegoOutput) IsServerError() bool { + return (*context.BeegoOutput)(output).IsServerError() +} + +// Session sets session item value with given key. +func (output *BeegoOutput) Session(name interface{}, value interface{}) { + (*context.BeegoOutput)(output).Session(name, value) +} diff --git a/pkg/adapter/context/renderer.go b/pkg/adapter/context/renderer.go new file mode 100644 index 00000000..763fb9c4 --- /dev/null +++ b/pkg/adapter/context/renderer.go @@ -0,0 +1,8 @@ +package context + +import ( + "github.com/astaxie/beego/pkg/server/web/context" +) + +// Renderer defines an http response renderer +type Renderer context.Renderer diff --git a/pkg/adapter/context/response.go b/pkg/adapter/context/response.go new file mode 100644 index 00000000..24e196a4 --- /dev/null +++ b/pkg/adapter/context/response.go @@ -0,0 +1,26 @@ +package context + +import ( + "net/http" + "strconv" +) + +const ( + // BadRequest indicates http error 400 + BadRequest StatusCode = http.StatusBadRequest + + // NotFound indicates http error 404 + NotFound StatusCode = http.StatusNotFound +) + +// StatusCode sets the http response status code +type StatusCode int + +func (s StatusCode) Error() string { + return strconv.Itoa(int(s)) +} + +// Render sets the http status code +func (s StatusCode) Render(ctx *Context) { + ctx.Output.SetStatus(int(s)) +} diff --git a/pkg/adapter/controller.go b/pkg/adapter/controller.go new file mode 100644 index 00000000..010add64 --- /dev/null +++ b/pkg/adapter/controller.go @@ -0,0 +1,401 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "mime/multipart" + "net/url" + + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/adapter/session" + webContext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +var ( + // ErrAbort custom error when user stop request handler manually. + ErrAbort = web.ErrAbort + // GlobalControllerRouter store comments with controller. pkgpath+controller:comments + GlobalControllerRouter = web.GlobalControllerRouter +) + +// ControllerFilter store the filter for controller +type ControllerFilter web.ControllerFilter + +// ControllerFilterComments store the comment for controller level filter +type ControllerFilterComments web.ControllerFilterComments + +// ControllerImportComments store the import comment for controller needed +type ControllerImportComments web.ControllerImportComments + +// ControllerComments store the comment for the controller method +type ControllerComments web.ControllerComments + +// ControllerCommentsSlice implements the sort interface +type ControllerCommentsSlice web.ControllerCommentsSlice + +func (p ControllerCommentsSlice) Len() int { + return (web.ControllerCommentsSlice)(p).Len() +} +func (p ControllerCommentsSlice) Less(i, j int) bool { + return (web.ControllerCommentsSlice)(p).Less(i, j) +} +func (p ControllerCommentsSlice) Swap(i, j int) { + (web.ControllerCommentsSlice)(p).Swap(i, j) +} + +// Controller defines some basic http request handler operations, such as +// http context, template and view, session and xsrf. +type Controller web.Controller + +// ControllerInterface is an interface to uniform all controller handler. +type ControllerInterface web.ControllerInterface + +// Init generates default values of controller operations. +func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { + (*web.Controller)(c).Init((*webContext.Context)(ctx), controllerName, actionName, app) +} + +// Prepare runs after Init before request function execution. +func (c *Controller) Prepare() { + (*web.Controller)(c).Prepare() +} + +// Finish runs after request function execution. +func (c *Controller) Finish() { + (*web.Controller)(c).Finish() +} + +// Get adds a request function to handle GET request. +func (c *Controller) Get() { + (*web.Controller)(c).Get() +} + +// Post adds a request function to handle POST request. +func (c *Controller) Post() { + (*web.Controller)(c).Post() +} + +// Delete adds a request function to handle DELETE request. +func (c *Controller) Delete() { + (*web.Controller)(c).Delete() +} + +// Put adds a request function to handle PUT request. +func (c *Controller) Put() { + (*web.Controller)(c).Put() +} + +// Head adds a request function to handle HEAD request. +func (c *Controller) Head() { + (*web.Controller)(c).Head() +} + +// Patch adds a request function to handle PATCH request. +func (c *Controller) Patch() { + (*web.Controller)(c).Patch() +} + +// Options adds a request function to handle OPTIONS request. +func (c *Controller) Options() { + (*web.Controller)(c).Options() +} + +// Trace adds a request function to handle Trace request. +// this method SHOULD NOT be overridden. +// https://tools.ietf.org/html/rfc7231#section-4.3.8 +// The TRACE method requests a remote, application-level loop-back of +// the request message. The final recipient of the request SHOULD +// reflect the message received, excluding some fields described below, +// back to the client as the message body of a 200 (OK) response with a +// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]). +func (c *Controller) Trace() { + (*web.Controller)(c).Trace() +} + +// HandlerFunc call function with the name +func (c *Controller) HandlerFunc(fnname string) bool { + return (*web.Controller)(c).HandlerFunc(fnname) +} + +// URLMapping register the internal Controller router. +func (c *Controller) URLMapping() { + (*web.Controller)(c).URLMapping() +} + +// Mapping the method to function +func (c *Controller) Mapping(method string, fn func()) { + (*web.Controller)(c).Mapping(method, fn) +} + +// Render sends the response with rendered template bytes as text/html type. +func (c *Controller) Render() error { + return (*web.Controller)(c).Render() +} + +// RenderString returns the rendered template string. Do not send out response. +func (c *Controller) RenderString() (string, error) { + return (*web.Controller)(c).RenderString() +} + +// RenderBytes returns the bytes of rendered template string. Do not send out response. +func (c *Controller) RenderBytes() ([]byte, error) { + return (*web.Controller)(c).RenderBytes() +} + +// Redirect sends the redirection response to url with status code. +func (c *Controller) Redirect(url string, code int) { + (*web.Controller)(c).Redirect(url, code) +} + +// SetData set the data depending on the accepted +func (c *Controller) SetData(data interface{}) { + (*web.Controller)(c).SetData(data) +} + +// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. +func (c *Controller) Abort(code string) { + (*web.Controller)(c).Abort(code) +} + +// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. +func (c *Controller) CustomAbort(status int, body string) { + (*web.Controller)(c).CustomAbort(status, body) +} + +// StopRun makes panic of USERSTOPRUN error and go to recover function if defined. +func (c *Controller) StopRun() { + (*web.Controller)(c).StopRun() +} + +// URLFor does another controller handler in this request function. +// it goes to this controller method if endpoint is not clear. +func (c *Controller) URLFor(endpoint string, values ...interface{}) string { + return (*web.Controller)(c).URLFor(endpoint, values...) +} + +// ServeJSON sends a json response with encoding charset. +func (c *Controller) ServeJSON(encoding ...bool) { + (*web.Controller)(c).ServeJSON(encoding...) +} + +// ServeJSONP sends a jsonp response. +func (c *Controller) ServeJSONP() { + (*web.Controller)(c).ServeJSONP() +} + +// ServeXML sends xml response. +func (c *Controller) ServeXML() { + (*web.Controller)(c).ServeXML() +} + +// ServeYAML sends yaml response. +func (c *Controller) ServeYAML() { + (*web.Controller)(c).ServeYAML() +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (c *Controller) ServeFormatted(encoding ...bool) { + (*web.Controller)(c).ServeFormatted(encoding...) +} + +// Input returns the input data map from POST or PUT request body and query string. +func (c *Controller) Input() url.Values { + return (*web.Controller)(c).Input() +} + +// ParseForm maps input data map to obj struct. +func (c *Controller) ParseForm(obj interface{}) error { + return (*web.Controller)(c).ParseForm(obj) +} + +// GetString returns the input value by key string or the default value while it's present and input is blank +func (c *Controller) GetString(key string, def ...string) string { + return (*web.Controller)(c).GetString(key, def...) +} + +// GetStrings returns the input string slice by key string or the default value while it's present and input is blank +// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. +func (c *Controller) GetStrings(key string, def ...[]string) []string { + return (*web.Controller)(c).GetStrings(key, def...) +} + +// GetInt returns input as an int or the default value while it's present and input is blank +func (c *Controller) GetInt(key string, def ...int) (int, error) { + return (*web.Controller)(c).GetInt(key, def...) +} + +// GetInt8 return input as an int8 or the default value while it's present and input is blank +func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { + return (*web.Controller)(c).GetInt8(key, def...) +} + +// GetUint8 return input as an uint8 or the default value while it's present and input is blank +func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { + return (*web.Controller)(c).GetUint8(key, def...) +} + +// GetInt16 returns input as an int16 or the default value while it's present and input is blank +func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { + return (*web.Controller)(c).GetInt16(key, def...) +} + +// GetUint16 returns input as an uint16 or the default value while it's present and input is blank +func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { + return (*web.Controller)(c).GetUint16(key, def...) +} + +// GetInt32 returns input as an int32 or the default value while it's present and input is blank +func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { + return (*web.Controller)(c).GetInt32(key, def...) +} + +// GetUint32 returns input as an uint32 or the default value while it's present and input is blank +func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { + return (*web.Controller)(c).GetUint32(key, def...) +} + +// GetInt64 returns input value as int64 or the default value while it's present and input is blank. +func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { + return (*web.Controller)(c).GetInt64(key, def...) +} + +// GetUint64 returns input value as uint64 or the default value while it's present and input is blank. +func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { + return (*web.Controller)(c).GetUint64(key, def...) +} + +// GetBool returns input value as bool or the default value while it's present and input is blank. +func (c *Controller) GetBool(key string, def ...bool) (bool, error) { + return (*web.Controller)(c).GetBool(key, def...) +} + +// GetFloat returns input value as float64 or the default value while it's present and input is blank. +func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { + return (*web.Controller)(c).GetFloat(key, def...) +} + +// GetFile returns the file data in file upload field named as key. +// it returns the first one of multi-uploaded files. +func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, error) { + return (*web.Controller)(c).GetFile(key) +} + +// GetFiles return multi-upload files +// files, err:=c.GetFiles("myfiles") +// if err != nil { +// http.Error(w, err.Error(), http.StatusNoContent) +// return +// } +// for i, _ := range files { +// //for each fileheader, get a handle to the actual file +// file, err := files[i].Open() +// defer file.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //create destination file making sure the path is writeable. +// dst, err := os.Create("upload/" + files[i].Filename) +// defer dst.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //copy the uploaded file to the destination file +// if _, err := io.Copy(dst, file); err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// } +func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { + return (*web.Controller)(c).GetFiles(key) +} + +// SaveToFile saves uploaded file to new path. +// it only operates the first one of mutil-upload form file field. +func (c *Controller) SaveToFile(fromfile, tofile string) error { + return (*web.Controller)(c).SaveToFile(fromfile, tofile) +} + +// StartSession starts session and load old session data info this controller. +func (c *Controller) StartSession() session.Store { + s := (*web.Controller)(c).StartSession() + return session.CreateNewToOldStoreAdapter(s) +} + +// SetSession puts value into session. +func (c *Controller) SetSession(name interface{}, value interface{}) { + (*web.Controller)(c).SetSession(name, value) +} + +// GetSession gets value from session. +func (c *Controller) GetSession(name interface{}) interface{} { + return (*web.Controller)(c).GetSession(name) +} + +// DelSession removes value from session. +func (c *Controller) DelSession(name interface{}) { + (*web.Controller)(c).DelSession(name) +} + +// SessionRegenerateID regenerates session id for this session. +// the session data have no changes. +func (c *Controller) SessionRegenerateID() { + (*web.Controller)(c).SessionRegenerateID() +} + +// DestroySession cleans session data and session cookie. +func (c *Controller) DestroySession() { + (*web.Controller)(c).DestroySession() +} + +// IsAjax returns this request is ajax or not. +func (c *Controller) IsAjax() bool { + return (*web.Controller)(c).IsAjax() +} + +// GetSecureCookie returns decoded cookie value from encoded browser cookie values. +func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { + return (*web.Controller)(c).GetSecureCookie(Secret, key) +} + +// SetSecureCookie puts value into cookie after encoded the value. +func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) { + (*web.Controller)(c).SetSecureCookie(Secret, name, value, others...) +} + +// XSRFToken creates a CSRF token string and returns. +func (c *Controller) XSRFToken() string { + return (*web.Controller)(c).XSRFToken() +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (c *Controller) CheckXSRFCookie() bool { + return (*web.Controller)(c).CheckXSRFCookie() +} + +// XSRFFormHTML writes an input field contains xsrf token value. +func (c *Controller) XSRFFormHTML() string { + return (*web.Controller)(c).XSRFFormHTML() +} + +// GetControllerAndAction gets the executing controller name and action name. +func (c *Controller) GetControllerAndAction() (string, string) { + return (*web.Controller)(c).GetControllerAndAction() +} diff --git a/pkg/adapter/error.go b/pkg/adapter/error.go new file mode 100644 index 00000000..4f08aa8c --- /dev/null +++ b/pkg/adapter/error.go @@ -0,0 +1,202 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +const ( + errorTypeHandler = iota + errorTypeController +) + +var tpl = ` + + + + + beego application error + + + + + +
+ + + + + + + + + + +
Request Method: {{.RequestMethod}}
Request URL: {{.RequestURL}}
RemoteAddr: {{.RemoteAddr }}
+
+ Stack +
{{.Stack}}
+
+
+ + + +` + +var errtpl = ` + + + + + {{.Title}} + + + +
+
+ +
+ {{.Content}} + Go Home
+ +
Powered by beego {{.BeegoVersion}} +
+
+
+ + +` + +// ErrorMaps holds map of http handlers for each error string. +// there is 10 kinds default error(40x and 50x) +var ErrorMaps = web.ErrorMaps + +// ErrorHandler registers http.HandlerFunc to each http err code string. +// usage: +// beego.ErrorHandler("404",NotFound) +// beego.ErrorHandler("500",InternalServerError) +func ErrorHandler(code string, h http.HandlerFunc) *App { + return (*App)(web.ErrorHandler(code, h)) +} + +// ErrorController registers ControllerInterface to each http err code string. +// usage: +// beego.ErrorController(&controllers.ErrorController{}) +func ErrorController(c ControllerInterface) *App { + return (*App)(web.ErrorController(c)) +} + +// Exception Write HttpStatus with errCode and Exec error handler if exist. +func Exception(errCode uint64, ctx *context.Context) { + web.Exception(errCode, (*beecontext.Context)(ctx)) +} diff --git a/pkg/adapter/filter.go b/pkg/adapter/filter.go new file mode 100644 index 00000000..cafed773 --- /dev/null +++ b/pkg/adapter/filter.go @@ -0,0 +1,36 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web" + beecontext "github.com/astaxie/beego/pkg/server/web/context" +) + +// FilterFunc defines a filter function which is invoked before the controller handler is executed. +type FilterFunc func(*context.Context) + +// FilterRouter defines a filter operation which is invoked before the controller handler is executed. +// It can match the URL against a pattern, and execute a filter function +// when a request with a matching URL arrives. +type FilterRouter web.FilterRouter + +// ValidRouter checks if the current request is matched by this filter. +// If the request is matched, the values of the URL parameters defined +// by the filter pattern are also returned. +func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { + return (*web.FilterRouter)(f).ValidRouter(url, (*beecontext.Context)(ctx)) +} diff --git a/pkg/adapter/flash.go b/pkg/adapter/flash.go new file mode 100644 index 00000000..02e75ed6 --- /dev/null +++ b/pkg/adapter/flash.go @@ -0,0 +1,63 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/server/web" +) + +// FlashData is a tools to maintain data when using across request. +type FlashData web.FlashData + +// NewFlash return a new empty FlashData struct. +func NewFlash() *FlashData { + return (*FlashData)(web.NewFlash()) +} + +// Set message to flash +func (fd *FlashData) Set(key string, msg string, args ...interface{}) { + (*web.FlashData)(fd).Set(key, msg, args...) +} + +// Success writes success message to flash. +func (fd *FlashData) Success(msg string, args ...interface{}) { + (*web.FlashData)(fd).Success(msg, args...) +} + +// Notice writes notice message to flash. +func (fd *FlashData) Notice(msg string, args ...interface{}) { + (*web.FlashData)(fd).Notice(msg, args...) +} + +// Warning writes warning message to flash. +func (fd *FlashData) Warning(msg string, args ...interface{}) { + (*web.FlashData)(fd).Warning(msg, args...) +} + +// Error writes error message to flash. +func (fd *FlashData) Error(msg string, args ...interface{}) { + (*web.FlashData)(fd).Error(msg, args...) +} + +// Store does the saving operation of flash data. +// the data are encoded and saved in cookie. +func (fd *FlashData) Store(c *Controller) { + (*web.FlashData)(fd).Store((*web.Controller)(c)) +} + +// ReadFromRequest parsed flash data from encoded values in cookie. +func ReadFromRequest(c *Controller) *FlashData { + return (*FlashData)(web.ReadFromRequest((*web.Controller)(c))) +} diff --git a/pkg/adapter/fs.go b/pkg/adapter/fs.go new file mode 100644 index 00000000..07054ca3 --- /dev/null +++ b/pkg/adapter/fs.go @@ -0,0 +1,35 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + "path/filepath" + + "github.com/astaxie/beego/pkg/server/web" +) + +type FileSystem web.FileSystem + +func (d FileSystem) Open(name string) (http.File, error) { + return (web.FileSystem)(d).Open(name) +} + +// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or +// directory in the tree, including root. All errors that arise visiting files +// and directories are filtered by walkFn. +func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { + return web.Walk(fs, root, walkFn) +} diff --git a/pkg/adapter/grace/grace.go b/pkg/adapter/grace/grace.go new file mode 100644 index 00000000..3775e395 --- /dev/null +++ b/pkg/adapter/grace/grace.go @@ -0,0 +1,94 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package grace use to hot reload +// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ +// +// Usage: +// +// import( +// "log" +// "net/http" +// "os" +// +// "github.com/astaxie/beego/grace" +// ) +// +// func handler(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("WORLD!")) +// } +// +// func main() { +// mux := http.NewServeMux() +// mux.HandleFunc("/hello", handler) +// +// err := grace.ListenAndServe("localhost:8080", mux) +// if err != nil { +// log.Println(err) +// } +// log.Println("Server on 8080 stopped") +// os.Exit(0) +// } +package grace + +import ( + "net/http" + "time" + + "github.com/astaxie/beego/pkg/server/web/grace" +) + +const ( + // PreSignal is the position to add filter before signal + PreSignal = iota + // PostSignal is the position to add filter after signal + PostSignal + // StateInit represent the application inited + StateInit + // StateRunning represent the application is running + StateRunning + // StateShuttingDown represent the application is shutting down + StateShuttingDown + // StateTerminate represent the application is killed + StateTerminate +) + +var ( + + // DefaultReadTimeOut is the HTTP read timeout + DefaultReadTimeOut time.Duration + // DefaultWriteTimeOut is the HTTP Write timeout + DefaultWriteTimeOut time.Duration + // DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit + DefaultMaxHeaderBytes int + // DefaultTimeout is the shutdown server's timeout. default is 60s + DefaultTimeout = grace.DefaultTimeout +) + +// NewServer returns a new graceServer. +func NewServer(addr string, handler http.Handler) (srv *Server) { + return (*Server)(grace.NewServer(addr, handler)) +} + +// ListenAndServe refer http.ListenAndServe +func ListenAndServe(addr string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServe() +} + +// ListenAndServeTLS refer http.ListenAndServeTLS +func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServeTLS(certFile, keyFile) +} diff --git a/pkg/adapter/grace/server.go b/pkg/adapter/grace/server.go new file mode 100644 index 00000000..31c13f18 --- /dev/null +++ b/pkg/adapter/grace/server.go @@ -0,0 +1,48 @@ +package grace + +import ( + "os" + + "github.com/astaxie/beego/pkg/server/web/grace" +) + +// Server embedded http.Server +type Server grace.Server + +// Serve accepts incoming connections on the Listener l, +// creating a new service goroutine for each. +// The service goroutines read requests and then call srv.Handler to reply to them. +func (srv *Server) Serve() (err error) { + return (*grace.Server)(srv).Serve() +} + +// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve +// to handle requests on incoming connections. If srv.Addr is blank, ":http" is +// used. +func (srv *Server) ListenAndServe() (err error) { + return (*grace.Server)(srv).ListenAndServe() +} + +// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming TLS connections. +// +// Filenames containing a certificate and matching private key for the server must +// be provided. If the certificate is signed by a certificate authority, the +// certFile should be the concatenation of the server's certificate followed by the +// CA's certificate. +// +// If srv.Addr is blank, ":https" is used. +func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { + return (*grace.Server)(srv).ListenAndServeTLS(certFile, keyFile) +} + +// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming mutual TLS connections. +func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) error { + return (*grace.Server)(srv).ListenAndServeMutualTLS(certFile, keyFile, trustFile) +} + +// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. +func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) error { + return (*grace.Server)(srv).RegisterSignalHook(ppFlag, sig, f) +} diff --git a/pkg/adapter/httplib/httplib.go b/pkg/adapter/httplib/httplib.go new file mode 100644 index 00000000..d2ef36c1 --- /dev/null +++ b/pkg/adapter/httplib/httplib.go @@ -0,0 +1,300 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package httplib is used as http.Client +// Usage: +// +// import "github.com/astaxie/beego/httplib" +// +// b := httplib.Post("http://beego.me/") +// b.Param("username","astaxie") +// b.Param("password","123456") +// b.PostFile("uploadfile1", "httplib.pdf") +// b.PostFile("uploadfile2", "httplib.txt") +// str, err := b.String() +// if err != nil { +// t.Fatal(err) +// } +// fmt.Println(str) +// +// more docs http://beego.me/docs/module/httplib.md +package httplib + +import ( + "crypto/tls" + "net" + "net/http" + "net/url" + "time" + + "github.com/astaxie/beego/pkg/client/httplib" +) + +// SetDefaultSetting Overwrite default settings +func SetDefaultSetting(setting BeegoHTTPSettings) { + httplib.SetDefaultSetting(httplib.BeegoHTTPSettings(setting)) +} + +// NewBeegoRequest return *BeegoHttpRequest with specific method +func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { + return &BeegoHTTPRequest{ + delegate: httplib.NewBeegoRequest(rawurl, method), + } +} + +// Get returns *BeegoHttpRequest with GET method. +func Get(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "GET") +} + +// Post returns *BeegoHttpRequest with POST method. +func Post(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "POST") +} + +// Put returns *BeegoHttpRequest with PUT method. +func Put(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "PUT") +} + +// Delete returns *BeegoHttpRequest DELETE method. +func Delete(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "DELETE") +} + +// Head returns *BeegoHttpRequest with HEAD method. +func Head(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "HEAD") +} + +// BeegoHTTPSettings is the http.Client setting +type BeegoHTTPSettings httplib.BeegoHTTPSettings + +// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. +type BeegoHTTPRequest struct { + delegate *httplib.BeegoHTTPRequest +} + +// GetRequest return the request object +func (b *BeegoHTTPRequest) GetRequest() *http.Request { + return b.delegate.GetRequest() +} + +// Setting Change request settings +func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { + b.delegate.Setting(httplib.BeegoHTTPSettings(setting)) + return b +} + +// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. +func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest { + b.delegate.SetBasicAuth(username, password) + return b +} + +// SetEnableCookie sets enable/disable cookiejar +func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest { + b.delegate.SetEnableCookie(enable) + return b +} + +// SetUserAgent sets User-Agent header field +func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { + b.delegate.SetUserAgent(useragent) + return b +} + +// Debug sets show debug or not when executing request. +func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { + b.delegate.Debug(isdebug) + return b +} + +// Retries sets Retries times. +// default is 0 means no retried. +// -1 means retried forever. +// others means retried times. +func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { + b.delegate.Retries(times) + return b +} + +func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { + b.delegate.RetryDelay(delay) + return b +} + +// DumpBody setting whether need to Dump the Body. +func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { + b.delegate.DumpBody(isdump) + return b +} + +// DumpRequest return the DumpRequest +func (b *BeegoHTTPRequest) DumpRequest() []byte { + return b.delegate.DumpRequest() +} + +// SetTimeout sets connect time out and read-write time out for BeegoRequest. +func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { + b.delegate.SetTimeout(connectTimeout, readWriteTimeout) + return b +} + +// SetTLSClientConfig sets tls connection configurations if visiting https url. +func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { + b.delegate.SetTLSClientConfig(config) + return b +} + +// Header add header item string in request. +func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { + b.delegate.Header(key, value) + return b +} + +// SetHost set the request host +func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { + b.delegate.SetHost(host) + return b +} + +// SetProtocolVersion Set the protocol version for incoming requests. +// Client requests always use HTTP/1.1. +func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { + b.delegate.SetProtocolVersion(vers) + return b +} + +// SetCookie add cookie into request. +func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { + b.delegate.SetCookie(cookie) + return b +} + +// SetTransport set the setting transport +func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { + b.delegate.SetTransport(transport) + return b +} + +// SetProxy set the http proxy +// example: +// +// func(req *http.Request) (*url.URL, error) { +// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") +// return u, nil +// } +func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest { + b.delegate.SetProxy(proxy) + return b +} + +// SetCheckRedirect specifies the policy for handling redirects. +// +// If CheckRedirect is nil, the Client uses its default policy, +// which is to stop after 10 consecutive requests. +func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest { + b.delegate.SetCheckRedirect(redirect) + return b +} + +// Param adds query param in to request. +// params build query string as ?key1=value1&key2=value2... +func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { + b.delegate.Param(key, value) + return b +} + +// PostFile add a post file to the request +func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { + b.delegate.PostFile(formname, filename) + return b +} + +// Body adds request raw body. +// it supports string and []byte. +func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { + b.delegate.Body(data) + return b +} + +// XMLBody adds request raw body encoding by XML. +func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + _, err := b.delegate.XMLBody(obj) + return b, err +} + +// YAMLBody adds request raw body encoding by YAML. +func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + _, err := b.delegate.YAMLBody(obj) + return b, err +} + +// JSONBody adds request raw body encoding by JSON. +func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { + _, err := b.delegate.JSONBody(obj) + return b, err +} + +// DoRequest will do the client.Do +func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { + return b.delegate.DoRequest() +} + +// String returns the body string in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) String() (string, error) { + return b.delegate.String() +} + +// Bytes returns the body []byte in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { + return b.delegate.Bytes() +} + +// ToFile saves the body data in response to one file. +// it calls Response inner. +func (b *BeegoHTTPRequest) ToFile(filename string) error { + return b.delegate.ToFile(filename) +} + +// ToJSON returns the map that marshals from the body bytes as json in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { + return b.delegate.ToJSON(v) +} + +// ToXML returns the map that marshals from the body bytes as xml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToXML(v interface{}) error { + return b.delegate.ToXML(v) +} + +// ToYAML returns the map that marshals from the body bytes as yaml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { + return b.delegate.ToYAML(v) +} + +// Response executes request client gets response mannually. +func (b *BeegoHTTPRequest) Response() (*http.Response, error) { + return b.delegate.Response() +} + +// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. +func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { + return httplib.TimeoutDialer(cTimeout, rwTimeout) +} diff --git a/pkg/adapter/httplib/httplib_test.go b/pkg/adapter/httplib/httplib_test.go new file mode 100644 index 00000000..e7605c87 --- /dev/null +++ b/pkg/adapter/httplib/httplib_test.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "errors" + "io/ioutil" + "net" + "net/http" + "os" + "strings" + "testing" + "time" +) + +func TestResponse(t *testing.T) { + req := Get("http://httpbin.org/get") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) +} + +func TestDoRequest(t *testing.T) { + req := Get("https://goolnk.com/33BD2j") + retryAmount := 1 + req.Retries(1) + req.RetryDelay(1400 * time.Millisecond) + retryDelay := 1400 * time.Millisecond + + req.SetCheckRedirect(func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + }) + + startTime := time.Now().UnixNano() / int64(time.Millisecond) + + _, err := req.Response() + if err == nil { + t.Fatal("Response should have yielded an error") + } + + endTime := time.Now().UnixNano() / int64(time.Millisecond) + elapsedTime := endTime - startTime + delayedTime := int64(retryAmount) * retryDelay.Milliseconds() + + if elapsedTime < delayedTime { + t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) + } + +} + +func TestGet(t *testing.T) { + req := Get("http://httpbin.org/get") + b, err := req.Bytes() + if err != nil { + t.Fatal(err) + } + t.Log(b) + + s, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(s) + + if string(b) != s { + t.Fatal("request data not match") + } +} + +func TestSimplePost(t *testing.T) { + v := "smallfish" + req := Post("http://httpbin.org/post") + req.Param("username", v) + + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in post") + } +} + +// func TestPostFile(t *testing.T) { +// v := "smallfish" +// req := Post("http://httpbin.org/post") +// req.Debug(true) +// req.Param("username", v) +// req.PostFile("uploadfile", "httplib_test.go") + +// str, err := req.String() +// if err != nil { +// t.Fatal(err) +// } +// t.Log(str) + +// n := strings.Index(str, v) +// if n == -1 { +// t.Fatal(v + " not found in post") +// } +// } + +func TestSimplePut(t *testing.T) { + str, err := Put("http://httpbin.org/put").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDelete(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDeleteParam(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").Param("key", "val").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestWithCookie(t *testing.T) { + v := "smallfish" + str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + str, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in cookie") + } +} + +func TestWithBasicAuth(t *testing.T) { + str, err := Get("http://httpbin.org/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + n := strings.Index(str, "authenticated") + if n == -1 { + t.Fatal("authenticated not found in response") + } +} + +func TestWithUserAgent(t *testing.T) { + v := "beego" + str, err := Get("http://httpbin.org/headers").SetUserAgent(v).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestWithSetting(t *testing.T) { + v := "beego" + var setting BeegoHTTPSettings + setting.EnableCookie = true + setting.UserAgent = v + setting.Transport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 50, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + setting.ReadWriteTimeout = 5 * time.Second + SetDefaultSetting(setting) + + str, err := Get("http://httpbin.org/get").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestToJson(t *testing.T) { + req := Get("http://httpbin.org/ip") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) + + // httpbin will return http remote addr + type IP struct { + Origin string `json:"origin"` + } + var ip IP + err = req.ToJSON(&ip) + if err != nil { + t.Fatal(err) + } + t.Log(ip.Origin) + ips := strings.Split(ip.Origin, ",") + if len(ips) == 0 { + t.Fatal("response is not valid ip") + } + for i := range ips { + if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil { + t.Fatal("response is not valid ip") + } + } + +} + +func TestToFile(t *testing.T) { + f := "beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.Remove(f) + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestToFileDir(t *testing.T) { + f := "./files/beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll("./files") + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestHeader(t *testing.T) { + req := Get("http://httpbin.org/headers") + req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36") + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} diff --git a/pkg/adapter/log.go b/pkg/adapter/log.go new file mode 100644 index 00000000..d9ff6e0c --- /dev/null +++ b/pkg/adapter/log.go @@ -0,0 +1,129 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "strings" + + "github.com/astaxie/beego/pkg/infrastructure/logs" + + webLog "github.com/astaxie/beego/pkg/infrastructure/logs" +) + +// Log levels to control the logging output. +// Deprecated: use github.com/astaxie/beego/logs instead. +const ( + LevelEmergency = webLog.LevelEmergency + LevelAlert = webLog.LevelAlert + LevelCritical = webLog.LevelCritical + LevelError = webLog.LevelError + LevelWarning = webLog.LevelWarning + LevelNotice = webLog.LevelNotice + LevelInformational = webLog.LevelInformational + LevelDebug = webLog.LevelDebug +) + +// BeeLogger references the used application logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +var BeeLogger = logs.GetBeeLogger() + +// SetLevel sets the global log level used by the simple logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLevel(l int) { + logs.SetLevel(l) +} + +// SetLogFuncCall set the CallDepth, default is 3 +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogFuncCall(b bool) { + logs.SetLogFuncCall(b) +} + +// SetLogger sets a new logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogger(adaptername string, config string) error { + return logs.SetLogger(adaptername, config) +} + +// Emergency logs a message at emergency level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Emergency(v ...interface{}) { + logs.Emergency(generateFmtStr(len(v)), v...) +} + +// Alert logs a message at alert level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Alert(v ...interface{}) { + logs.Alert(generateFmtStr(len(v)), v...) +} + +// Critical logs a message at critical level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Critical(v ...interface{}) { + logs.Critical(generateFmtStr(len(v)), v...) +} + +// Error logs a message at error level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Error(v ...interface{}) { + logs.Error(generateFmtStr(len(v)), v...) +} + +// Warning logs a message at warning level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warning(v ...interface{}) { + logs.Warning(generateFmtStr(len(v)), v...) +} + +// Warn compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warn(v ...interface{}) { + logs.Warn(generateFmtStr(len(v)), v...) +} + +// Notice logs a message at notice level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Notice(v ...interface{}) { + logs.Notice(generateFmtStr(len(v)), v...) +} + +// Informational logs a message at info level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Informational(v ...interface{}) { + logs.Informational(generateFmtStr(len(v)), v...) +} + +// Info compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Info(v ...interface{}) { + logs.Info(generateFmtStr(len(v)), v...) +} + +// Debug logs a message at debug level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Debug(v ...interface{}) { + logs.Debug(generateFmtStr(len(v)), v...) +} + +// Trace logs a message at trace level. +// compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Trace(v ...interface{}) { + logs.Trace(generateFmtStr(len(v)), v...) +} + +func generateFmtStr(n int) string { + return strings.Repeat("%v ", n) +} diff --git a/pkg/adapter/metric/prometheus.go b/pkg/adapter/metric/prometheus.go new file mode 100644 index 00000000..1d3488c6 --- /dev/null +++ b/pkg/adapter/metric/prometheus.go @@ -0,0 +1,99 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "reflect" + "strconv" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/server/web" +) + +func PrometheusMiddleWare(next http.Handler) http.Handler { + summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "http_request", + ConstLabels: map[string]string{ + "server": web.BConfig.ServerName, + "env": web.BConfig.RunMode, + "appname": web.BConfig.AppName, + }, + Help: "The statics info for http request", + }, []string{"pattern", "method", "status", "duration"}) + + prometheus.MustRegister(summaryVec) + + registerBuildInfo() + + return http.HandlerFunc(func(writer http.ResponseWriter, q *http.Request) { + start := time.Now() + next.ServeHTTP(writer, q) + end := time.Now() + go report(end.Sub(start), writer, q, summaryVec) + }) +} + +func registerBuildInfo() { + buildInfo := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "beego", + Subsystem: "build_info", + Help: "The building information", + ConstLabels: map[string]string{ + "appname": web.BConfig.AppName, + "build_version": web.BuildVersion, + "build_revision": web.BuildGitRevision, + "build_status": web.BuildStatus, + "build_tag": web.BuildTag, + "build_time": strings.Replace(web.BuildTime, "--", " ", 1), + "go_version": web.GoVersion, + "git_branch": web.GitBranch, + "start_time": time.Now().Format("2006-01-02 15:04:05"), + }, + }, []string{}) + + prometheus.MustRegister(buildInfo) + buildInfo.WithLabelValues().Set(1) +} + +func report(dur time.Duration, writer http.ResponseWriter, q *http.Request, vec *prometheus.SummaryVec) { + ctrl := web.BeeApp.Handlers + ctx := ctrl.GetContext() + ctx.Reset(writer, q) + defer ctrl.GiveBackContext(ctx) + + // We cannot read the status code from q.Response.StatusCode + // since the http server does not set q.Response. So q.Response is nil + // Thus, we use reflection to read the status from writer whose concrete type is http.response + responseVal := reflect.ValueOf(writer).Elem() + field := responseVal.FieldByName("status") + status := -1 + if field.IsValid() && field.Kind() == reflect.Int { + status = int(field.Int()) + } + ptn := "UNKNOWN" + if rt, found := ctrl.FindRouter(ctx); found { + ptn = rt.GetPattern() + } else { + logs.Warn("we can not find the router info for this request, so request will be recorded as UNKNOWN: " + q.URL.String()) + } + ms := dur / time.Millisecond + vec.WithLabelValues(ptn, q.Method, strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms)) +} diff --git a/pkg/adapter/metric/prometheus_test.go b/pkg/adapter/metric/prometheus_test.go new file mode 100644 index 00000000..87286e02 --- /dev/null +++ b/pkg/adapter/metric/prometheus_test.go @@ -0,0 +1,42 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "net/url" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego/pkg/adapter/context" +) + +func TestPrometheusMiddleWare(t *testing.T) { + middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + writer := &context.Response{} + request := &http.Request{ + URL: &url.URL{ + Host: "localhost", + RawPath: "/a/b/c", + }, + Method: "POST", + } + vec := prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{"pattern", "method", "status", "duration"}) + + report(time.Second, writer, request, vec) + middleware.ServeHTTP(writer, request) +} diff --git a/pkg/adapter/migration/ddl.go b/pkg/adapter/migration/ddl.go new file mode 100644 index 00000000..97e45dec --- /dev/null +++ b/pkg/adapter/migration/ddl.go @@ -0,0 +1,198 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package migration + +import ( + "github.com/astaxie/beego/pkg/client/orm/migration" +) + +// Index struct defines the structure of Index Columns +type Index migration.Index + +// Unique struct defines a single unique key combination +type Unique migration.Unique + +// Column struct defines a single column of a table +type Column migration.Column + +// Foreign struct defines a single foreign relationship +type Foreign migration.Foreign + +// RenameColumn struct allows renaming of columns +type RenameColumn migration.RenameColumn + +// CreateTable creates the table on system +func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) { + (*migration.Migration)(m).CreateTable(tablename, engine, charset, p...) +} + +// AlterTable set the ModifyType to alter +func (m *Migration) AlterTable(tablename string) { + (*migration.Migration)(m).AlterTable(tablename) +} + +// NewCol creates a new standard column and attaches it to m struct +func (m *Migration) NewCol(name string) *Column { + return (*Column)((*migration.Migration)(m).NewCol(name)) +} + +// PriCol creates a new primary column and attaches it to m struct +func (m *Migration) PriCol(name string) *Column { + return (*Column)((*migration.Migration)(m).PriCol(name)) +} + +// UniCol creates / appends columns to specified unique key and attaches it to m struct +func (m *Migration) UniCol(uni, name string) *Column { + return (*Column)((*migration.Migration)(m).UniCol(uni, name)) +} + +// ForeignCol creates a new foreign column and returns the instance of column +func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { + return (*Foreign)((*migration.Migration)(m).ForeignCol(colname, foreigncol, foreigntable)) +} + +// SetOnDelete sets the on delete of foreign +func (foreign *Foreign) SetOnDelete(del string) *Foreign { + (*migration.Foreign)(foreign).SetOnDelete(del) + return foreign +} + +// SetOnUpdate sets the on update of foreign +func (foreign *Foreign) SetOnUpdate(update string) *Foreign { + (*migration.Foreign)(foreign).SetOnUpdate(update) + return foreign +} + +// Remove marks the columns to be removed. +// it allows reverse m to create the column. +func (c *Column) Remove() { + (*migration.Column)(c).Remove() +} + +// SetAuto enables auto_increment of column (can be used once) +func (c *Column) SetAuto(inc bool) *Column { + (*migration.Column)(c).SetAuto(inc) + return c +} + +// SetNullable sets the column to be null +func (c *Column) SetNullable(null bool) *Column { + (*migration.Column)(c).SetNullable(null) + return c +} + +// SetDefault sets the default value, prepend with "DEFAULT " +func (c *Column) SetDefault(def string) *Column { + (*migration.Column)(c).SetDefault(def) + return c +} + +// SetUnsigned sets the column to be unsigned int +func (c *Column) SetUnsigned(unsign bool) *Column { + (*migration.Column)(c).SetUnsigned(unsign) + return c +} + +// SetDataType sets the dataType of the column +func (c *Column) SetDataType(dataType string) *Column { + (*migration.Column)(c).SetDataType(dataType) + return c +} + +// SetOldNullable allows reverting to previous nullable on reverse ms +func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { + (*migration.RenameColumn)(c).SetOldNullable(null) + return c +} + +// SetOldDefault allows reverting to previous default on reverse ms +func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { + (*migration.RenameColumn)(c).SetOldDefault(def) + return c +} + +// SetOldUnsigned allows reverting to previous unsgined on reverse ms +func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { + (*migration.RenameColumn)(c).SetOldUnsigned(unsign) + return c +} + +// SetOldDataType allows reverting to previous datatype on reverse ms +func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { + (*migration.RenameColumn)(c).SetOldDataType(dataType) + return c +} + +// SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) +func (c *Column) SetPrimary(m *Migration) *Column { + (*migration.Column)(c).SetPrimary((*migration.Migration)(m)) + return c +} + +// AddColumnsToUnique adds the columns to Unique Struct +func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { + cls := toNewColumnsArray(columns) + (*migration.Unique)(unique).AddColumnsToUnique(cls...) + return unique +} + +// AddColumns adds columns to m struct +func (m *Migration) AddColumns(columns ...*Column) *Migration { + cls := toNewColumnsArray(columns) + (*migration.Migration)(m).AddColumns(cls...) + return m +} + +func toNewColumnsArray(columns []*Column) []*migration.Column { + cls := make([]*migration.Column, 0, len(columns)) + for _, c := range columns { + cls = append(cls, (*migration.Column)(c)) + } + return cls +} + +// AddPrimary adds the column to primary in m struct +func (m *Migration) AddPrimary(primary *Column) *Migration { + (*migration.Migration)(m).AddPrimary((*migration.Column)(primary)) + return m +} + +// AddUnique adds the column to unique in m struct +func (m *Migration) AddUnique(unique *Unique) *Migration { + (*migration.Migration)(m).AddUnique((*migration.Unique)(unique)) + return m +} + +// AddForeign adds the column to foreign in m struct +func (m *Migration) AddForeign(foreign *Foreign) *Migration { + (*migration.Migration)(m).AddForeign((*migration.Foreign)(foreign)) + return m +} + +// AddIndex adds the column to index in m struct +func (m *Migration) AddIndex(index *Index) *Migration { + (*migration.Migration)(m).AddIndex((*migration.Index)(index)) + return m +} + +// RenameColumn allows renaming of columns +func (m *Migration) RenameColumn(from, to string) *RenameColumn { + return (*RenameColumn)((*migration.Migration)(m).RenameColumn(from, to)) +} + +// GetSQL returns the generated sql depending on ModifyType +func (m *Migration) GetSQL() (sql string) { + return (*migration.Migration)(m).GetSQL() +} diff --git a/pkg/adapter/migration/doc.go b/pkg/adapter/migration/doc.go new file mode 100644 index 00000000..0c6564d4 --- /dev/null +++ b/pkg/adapter/migration/doc.go @@ -0,0 +1,32 @@ +// Package migration enables you to generate migrations back and forth. It generates both migrations. +// +// //Creates a table +// m.CreateTable("tablename","InnoDB","utf8"); +// +// //Alter a table +// m.AlterTable("tablename") +// +// Standard Column Methods +// * SetDataType +// * SetNullable +// * SetDefault +// * SetUnsigned (use only on integer types unless produces error) +// +// //Sets a primary column, multiple calls allowed, standard column methods available +// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true) +// +// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index +// m.UniCol("index","column") +// +// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove +// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false) +// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false) +// +// //Rename Columns , only use with Alter table, doesn't works with Create, prefix standard column methods with "Old" to +// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)") +// m.RenameColumn("from","to")... +// +// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately. +// //Supports standard column methods, automatic reverse. +// m.ForeignCol("local_col","foreign_col","foreign_table") +package migration diff --git a/pkg/adapter/migration/migration.go b/pkg/adapter/migration/migration.go new file mode 100644 index 00000000..4ee22e5a --- /dev/null +++ b/pkg/adapter/migration/migration.go @@ -0,0 +1,111 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package migration is used for migration +// +// The table structure is as follow: +// +// CREATE TABLE `migrations` ( +// `id_migration` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'surrogate key', +// `name` varchar(255) DEFAULT NULL COMMENT 'migration name, unique', +// `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'date migrated or rolled back', +// `statements` longtext COMMENT 'SQL statements for this migration', +// `rollback_statements` longtext, +// `status` enum('update','rollback') DEFAULT NULL COMMENT 'update indicates it is a normal migration while rollback means this migration is rolled back', +// PRIMARY KEY (`id_migration`) +// ) ENGINE=InnoDB DEFAULT CHARSET=utf8; +package migration + +import ( + "github.com/astaxie/beego/pkg/client/orm/migration" +) + +// const the data format for the bee generate migration datatype +const ( + DateFormat = "20060102_150405" + DBDateFormat = "2006-01-02 15:04:05" +) + +// Migrationer is an interface for all Migration struct +type Migrationer interface { + Up() + Down() + Reset() + Exec(name, status string) error + GetCreated() int64 +} + +// Migration defines the migrations by either SQL or DDL +type Migration migration.Migration + +// Up implement in the Inheritance struct for upgrade +func (m *Migration) Up() { + (*migration.Migration)(m).Up() +} + +// Down implement in the Inheritance struct for down +func (m *Migration) Down() { + (*migration.Migration)(m).Down() +} + +// Migrate adds the SQL to the execution list +func (m *Migration) Migrate(migrationType string) { + (*migration.Migration)(m).Migrate(migrationType) +} + +// SQL add sql want to execute +func (m *Migration) SQL(sql string) { + (*migration.Migration)(m).SQL(sql) +} + +// Reset the sqls +func (m *Migration) Reset() { + (*migration.Migration)(m).Reset() +} + +// Exec execute the sql already add in the sql +func (m *Migration) Exec(name, status string) error { + return (*migration.Migration)(m).Exec(name, status) +} + +// GetCreated get the unixtime from the Created +func (m *Migration) GetCreated() int64 { + return (*migration.Migration)(m).GetCreated() +} + +// Register register the Migration in the map +func Register(name string, m Migrationer) error { + return migration.Register(name, m) +} + +// Upgrade upgrade the migration from lasttime +func Upgrade(lasttime int64) error { + return migration.Upgrade(lasttime) +} + +// Rollback rollback the migration by the name +func Rollback(name string) error { + return migration.Rollback(name) +} + +// Reset reset all migration +// run all migration's down function +func Reset() error { + return migration.Reset() +} + +// Refresh first Reset, then Upgrade +func Refresh() error { + return migration.Refresh() +} diff --git a/pkg/adapter/namespace.go b/pkg/adapter/namespace.go new file mode 100644 index 00000000..609402cf --- /dev/null +++ b/pkg/adapter/namespace.go @@ -0,0 +1,378 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + + adtContext "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +type namespaceCond func(*adtContext.Context) bool + +// LinkNamespace used as link action +type LinkNamespace func(*Namespace) + +// Namespace is store all the info +type Namespace web.Namespace + +// NewNamespace get new Namespace +func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { + nps := oldToNewLinkNs(params) + return (*Namespace)(web.NewNamespace(prefix, nps...)) +} + +func oldToNewLinkNs(params []LinkNamespace) []web.LinkNamespace { + nps := make([]web.LinkNamespace, 0, len(params)) + for _, p := range params { + nps = append(nps, func(namespace *web.Namespace) { + p((*Namespace)(namespace)) + }) + } + return nps +} + +// Cond set condition function +// if cond return true can run this namespace, else can't +// usage: +// ns.Cond(func (ctx *context.Context) bool{ +// if ctx.Input.Domain() == "api.beego.me" { +// return true +// } +// return false +// }) +// Cond as the first filter +func (n *Namespace) Cond(cond namespaceCond) *Namespace { + (*web.Namespace)(n).Cond(func(context *context.Context) bool { + return cond((*adtContext.Context)(context)) + }) + return n +} + +// Filter add filter in the Namespace +// action has before & after +// FilterFunc +// usage: +// Filter("before", func (ctx *context.Context){ +// _, ok := ctx.Input.Session("uid").(int) +// if !ok && ctx.Request.RequestURI != "/login" { +// ctx.Redirect(302, "/login") +// } +// }) +func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { + nfs := oldToNewFilter(filter) + (*web.Namespace)(n).Filter(action, nfs...) + return n +} + +func oldToNewFilter(filter []FilterFunc) []web.FilterFunc { + nfs := make([]web.FilterFunc, 0, len(filter)) + for _, f := range filter { + nfs = append(nfs, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } + return nfs +} + +// Router same as beego.Rourer +// refer: https://godoc.org/github.com/astaxie/beego#Router +func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { + (*web.Namespace)(n).Router(rootpath, c, mappingMethods...) + return n +} + +// AutoRouter same as beego.AutoRouter +// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter +func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { + (*web.Namespace)(n).AutoRouter(c) + return n +} + +// AutoPrefix same as beego.AutoPrefix +// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix +func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { + (*web.Namespace)(n).AutoPrefix(prefix, c) + return n +} + +// Get same as beego.Get +// refer: https://godoc.org/github.com/astaxie/beego#Get +func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Get(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Post same as beego.Post +// refer: https://godoc.org/github.com/astaxie/beego#Post +func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Post(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Delete same as beego.Delete +// refer: https://godoc.org/github.com/astaxie/beego#Delete +func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Delete(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Put same as beego.Put +// refer: https://godoc.org/github.com/astaxie/beego#Put +func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Put(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Head same as beego.Head +// refer: https://godoc.org/github.com/astaxie/beego#Head +func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Head(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Options same as beego.Options +// refer: https://godoc.org/github.com/astaxie/beego#Options +func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Options(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Patch same as beego.Patch +// refer: https://godoc.org/github.com/astaxie/beego#Patch +func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Patch(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Any same as beego.Any +// refer: https://godoc.org/github.com/astaxie/beego#Any +func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Any(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Handler same as beego.Handler +// refer: https://godoc.org/github.com/astaxie/beego#Handler +func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { + (*web.Namespace)(n).Handler(rootpath, h) + return n +} + +// Include add include class +// refer: https://godoc.org/github.com/astaxie/beego#Include +func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { + nL := oldToNewCtrlIntfs(cList) + (*web.Namespace)(n).Include(nL...) + return n +} + +// Namespace add nest Namespace +// usage: +// ns := beego.NewNamespace(“/v1”). +// Namespace( +// beego.NewNamespace("/shop"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("shopinfo")) +// }), +// beego.NewNamespace("/order"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("orderinfo")) +// }), +// beego.NewNamespace("/crm"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("crminfo")) +// }), +// ) +func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { + nns := oldToNewNs(ns) + (*web.Namespace)(n).Namespace(nns...) + return n +} + +func oldToNewNs(ns []*Namespace) []*web.Namespace { + nns := make([]*web.Namespace, 0, len(ns)) + for _, n := range ns { + nns = append(nns, (*web.Namespace)(n)) + } + return nns +} + +// AddNamespace register Namespace into beego.Handler +// support multi Namespace +func AddNamespace(nl ...*Namespace) { + nnl := oldToNewNs(nl) + web.AddNamespace(nnl...) +} + +// NSCond is Namespace Condition +func NSCond(cond namespaceCond) LinkNamespace { + return func(namespace *Namespace) { + web.NSCond(func(b *context.Context) bool { + return cond((*adtContext.Context)(b)) + }) + } +} + +// NSBefore Namespace BeforeRouter filter +func NSBefore(filterList ...FilterFunc) LinkNamespace { + return func(namespace *Namespace) { + nfs := oldToNewFilter(filterList) + web.NSBefore(nfs...) + } +} + +// NSAfter add Namespace FinishRouter filter +func NSAfter(filterList ...FilterFunc) LinkNamespace { + return func(namespace *Namespace) { + nfs := oldToNewFilter(filterList) + web.NSAfter(nfs...) + } +} + +// NSInclude Namespace Include ControllerInterface +func NSInclude(cList ...ControllerInterface) LinkNamespace { + return func(namespace *Namespace) { + nfs := oldToNewCtrlIntfs(cList) + web.NSInclude(nfs...) + } +} + +// NSRouter call Namespace Router +func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { + return func(namespace *Namespace) { + web.Router(rootpath, c, mappingMethods...) + } +} + +// NSGet call Namespace Get +func NSGet(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSGet(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSPost call Namespace Post +func NSPost(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.Post(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSHead call Namespace Head +func NSHead(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSHead(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSPut call Namespace Put +func NSPut(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSPut(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSDelete call Namespace Delete +func NSDelete(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSDelete(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSAny call Namespace Any +func NSAny(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSAny(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSOptions call Namespace Options +func NSOptions(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSOptions(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSPatch call Namespace Patch +func NSPatch(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSPatch(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSAutoRouter call Namespace AutoRouter +func NSAutoRouter(c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + web.NSAutoRouter(c) + } +} + +// NSAutoPrefix call Namespace AutoPrefix +func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + web.NSAutoPrefix(prefix, c) + } +} + +// NSNamespace add sub Namespace +func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { + return func(ns *Namespace) { + nps := oldToNewLinkNs(params) + web.NSNamespace(prefix, nps...) + } +} + +// NSHandler add handler +func NSHandler(rootpath string, h http.Handler) LinkNamespace { + return func(ns *Namespace) { + web.NSHandler(rootpath, h) + } +} diff --git a/pkg/adapter/orm/cmd.go b/pkg/adapter/orm/cmd.go new file mode 100644 index 00000000..6fee237c --- /dev/null +++ b/pkg/adapter/orm/cmd.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// RunCommand listen for orm command and then run it if command arguments passed. +func RunCommand() { + orm.RunCommand() +} + +func RunSyncdb(name string, force bool, verbose bool) error { + return orm.RunSyncdb(name, force, verbose) +} diff --git a/pkg/adapter/orm/db.go b/pkg/adapter/orm/db.go new file mode 100644 index 00000000..74bca8c0 --- /dev/null +++ b/pkg/adapter/orm/db.go @@ -0,0 +1,24 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +var ( + // ErrMissPK missing pk error + ErrMissPK = orm.ErrMissPK +) diff --git a/pkg/adapter/orm/db_alias.go b/pkg/adapter/orm/db_alias.go new file mode 100644 index 00000000..b1f1a724 --- /dev/null +++ b/pkg/adapter/orm/db_alias.go @@ -0,0 +1,122 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "time" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// DriverType database driver constant int. +type DriverType orm.DriverType + +// Enum the Database driver +const ( + _ DriverType = iota // int enum type + DRMySQL = orm.DRMySQL + DRSqlite = orm.DRSqlite // sqlite + DROracle = orm.DROracle // oracle + DRPostgres = orm.DRPostgres // pgsql + DRTiDB = orm.DRTiDB // TiDB +) + +type DB orm.DB + +func (d *DB) Begin() (*sql.Tx, error) { + return (*orm.DB)(d).Begin() +} + +func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return (*orm.DB)(d).BeginTx(ctx, opts) +} + +func (d *DB) Prepare(query string) (*sql.Stmt, error) { + return (*orm.DB)(d).Prepare(query) +} + +func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return (*orm.DB)(d).PrepareContext(ctx, query) +} + +func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + return (*orm.DB)(d).Exec(query, args...) +} + +func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return (*orm.DB)(d).ExecContext(ctx, query, args...) +} + +func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + return (*orm.DB)(d).Query(query, args...) +} + +func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return (*orm.DB)(d).QueryContext(ctx, query, args...) +} + +func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { + return (*orm.DB)(d).QueryRow(query, args) +} + +func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return (*orm.DB)(d).QueryRowContext(ctx, query, args...) +} + +// AddAliasWthDB add a aliasName for the drivename +func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { + return orm.AddAliasWthDB(aliasName, driverName, db) +} + +// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. +func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { + opts := make([]orm.DBOption, 0, 2) + if len(params) > 0 { + opts = append(opts, orm.MaxIdleConnections(params[0])) + } + + if len(params) > 1 { + opts = append(opts, orm.MaxOpenConnections(params[1])) + } + return orm.RegisterDataBase(aliasName, driverName, dataSource, opts...) +} + +// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. +func RegisterDriver(driverName string, typ DriverType) error { + return orm.RegisterDriver(driverName, orm.DriverType(typ)) +} + +// SetDataBaseTZ Change the database default used timezone +func SetDataBaseTZ(aliasName string, tz *time.Location) error { + return orm.SetDataBaseTZ(aliasName, tz) +} + +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +func SetMaxIdleConns(aliasName string, maxIdleConns int) { + orm.SetMaxIdleConns(aliasName, maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +func SetMaxOpenConns(aliasName string, maxOpenConns int) { + orm.SetMaxOpenConns(aliasName, maxOpenConns) +} + +// GetDB Get *sql.DB from registered database by db alias name. +// Use "default" as alias name if you not set. +func GetDB(aliasNames ...string) (*sql.DB, error) { + return orm.GetDB(aliasNames...) +} diff --git a/pkg/adapter/orm/models.go b/pkg/adapter/orm/models.go new file mode 100644 index 00000000..3215f5b5 --- /dev/null +++ b/pkg/adapter/orm/models.go @@ -0,0 +1,25 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// ResetModelCache Clean model cache. Then you can re-RegisterModel. +// Common use this api for test case. +func ResetModelCache() { + orm.ResetModelCache() +} diff --git a/pkg/adapter/orm/models_boot.go b/pkg/adapter/orm/models_boot.go new file mode 100644 index 00000000..8888ef65 --- /dev/null +++ b/pkg/adapter/orm/models_boot.go @@ -0,0 +1,40 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// RegisterModel register models +func RegisterModel(models ...interface{}) { + orm.RegisterModel(models...) +} + +// RegisterModelWithPrefix register models with a prefix +func RegisterModelWithPrefix(prefix string, models ...interface{}) { + orm.RegisterModelWithPrefix(prefix, models) +} + +// RegisterModelWithSuffix register models with a suffix +func RegisterModelWithSuffix(suffix string, models ...interface{}) { + orm.RegisterModelWithSuffix(suffix, models...) +} + +// BootStrap bootstrap models. +// make all model parsed and can not add more models +func BootStrap() { + orm.BootStrap() +} diff --git a/pkg/adapter/orm/models_fields.go b/pkg/adapter/orm/models_fields.go new file mode 100644 index 00000000..666a97dc --- /dev/null +++ b/pkg/adapter/orm/models_fields.go @@ -0,0 +1,625 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "time" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// Define the Type enum +const ( + TypeBooleanField = orm.TypeBooleanField + TypeVarCharField = orm.TypeVarCharField + TypeCharField = orm.TypeCharField + TypeTextField = orm.TypeTextField + TypeTimeField = orm.TypeTimeField + TypeDateField = orm.TypeDateField + TypeDateTimeField = orm.TypeDateTimeField + TypeBitField = orm.TypeBitField + TypeSmallIntegerField = orm.TypeSmallIntegerField + TypeIntegerField = orm.TypeIntegerField + TypeBigIntegerField = orm.TypeBigIntegerField + TypePositiveBitField = orm.TypePositiveBitField + TypePositiveSmallIntegerField = orm.TypePositiveSmallIntegerField + TypePositiveIntegerField = orm.TypePositiveIntegerField + TypePositiveBigIntegerField = orm.TypePositiveBigIntegerField + TypeFloatField = orm.TypeFloatField + TypeDecimalField = orm.TypeDecimalField + TypeJSONField = orm.TypeJSONField + TypeJsonbField = orm.TypeJsonbField + RelForeignKey = orm.RelForeignKey + RelOneToOne = orm.RelOneToOne + RelManyToMany = orm.RelManyToMany + RelReverseOne = orm.RelReverseOne + RelReverseMany = orm.RelReverseMany +) + +// Define some logic enum +const ( + IsIntegerField = orm.IsIntegerField + IsPositiveIntegerField = orm.IsPositiveIntegerField + IsRelField = orm.IsRelField + IsFieldType = orm.IsFieldType +) + +// BooleanField A true/false field. +type BooleanField orm.BooleanField + +// Value return the BooleanField +func (e BooleanField) Value() bool { + return orm.BooleanField(e).Value() +} + +// Set will set the BooleanField +func (e *BooleanField) Set(d bool) { + (*orm.BooleanField)(e).Set(d) +} + +// String format the Bool to string +func (e *BooleanField) String() string { + return (*orm.BooleanField)(e).String() +} + +// FieldType return BooleanField the type +func (e *BooleanField) FieldType() int { + return (*orm.BooleanField)(e).FieldType() +} + +// SetRaw set the interface to bool +func (e *BooleanField) SetRaw(value interface{}) error { + return (*orm.BooleanField)(e).SetRaw(value) +} + +// RawValue return the current value +func (e *BooleanField) RawValue() interface{} { + return (*orm.BooleanField)(e).RawValue() +} + +// verify the BooleanField implement the Fielder interface +var _ Fielder = new(BooleanField) + +// CharField A string field +// required values tag: size +// The size is enforced at the database level and in models’s validation. +// eg: `orm:"size(120)"` +type CharField orm.CharField + +// Value return the CharField's Value +func (e CharField) Value() string { + return orm.CharField(e).Value() +} + +// Set CharField value +func (e *CharField) Set(d string) { + (*orm.CharField)(e).Set(d) +} + +// String return the CharField +func (e *CharField) String() string { + return (*orm.CharField)(e).String() +} + +// FieldType return the enum type +func (e *CharField) FieldType() int { + return (*orm.CharField)(e).FieldType() +} + +// SetRaw set the interface to string +func (e *CharField) SetRaw(value interface{}) error { + return (*orm.CharField)(e).SetRaw(value) +} + +// RawValue return the CharField value +func (e *CharField) RawValue() interface{} { + return (*orm.CharField)(e).RawValue() +} + +// verify CharField implement Fielder +var _ Fielder = new(CharField) + +// TimeField A time, represented in go by a time.Time instance. +// only time values like 10:00:00 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type TimeField orm.TimeField + +// Value return the time.Time +func (e TimeField) Value() time.Time { + return orm.TimeField(e).Value() +} + +// Set set the TimeField's value +func (e *TimeField) Set(d time.Time) { + (*orm.TimeField)(e).Set(d) +} + +// String convert time to string +func (e *TimeField) String() string { + return (*orm.TimeField)(e).String() +} + +// FieldType return enum type Date +func (e *TimeField) FieldType() int { + return (*orm.TimeField)(e).FieldType() +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *TimeField) SetRaw(value interface{}) error { + return (*orm.TimeField)(e).SetRaw(value) +} + +// RawValue return time value +func (e *TimeField) RawValue() interface{} { + return (*orm.TimeField)(e).RawValue() +} + +var _ Fielder = new(TimeField) + +// DateField A date, represented in go by a time.Time instance. +// only date values like 2006-01-02 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type DateField orm.DateField + +// Value return the time.Time +func (e DateField) Value() time.Time { + return orm.DateField(e).Value() +} + +// Set set the DateField's value +func (e *DateField) Set(d time.Time) { + (*orm.DateField)(e).Set(d) +} + +// String convert datetime to string +func (e *DateField) String() string { + return (*orm.DateField)(e).String() +} + +// FieldType return enum type Date +func (e *DateField) FieldType() int { + return (*orm.DateField)(e).FieldType() +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *DateField) SetRaw(value interface{}) error { + return (*orm.DateField)(e).SetRaw(value) +} + +// RawValue return Date value +func (e *DateField) RawValue() interface{} { + return (*orm.DateField)(e).RawValue() +} + +// verify DateField implement fielder interface +var _ Fielder = new(DateField) + +// DateTimeField A date, represented in go by a time.Time instance. +// datetime values like 2006-01-02 15:04:05 +// Takes the same extra arguments as DateField. +type DateTimeField orm.DateTimeField + +// Value return the datetime value +func (e DateTimeField) Value() time.Time { + return orm.DateTimeField(e).Value() +} + +// Set set the time.Time to datetime +func (e *DateTimeField) Set(d time.Time) { + (*orm.DateTimeField)(e).Set(d) +} + +// String return the time's String +func (e *DateTimeField) String() string { + return (*orm.DateTimeField)(e).String() +} + +// FieldType return the enum TypeDateTimeField +func (e *DateTimeField) FieldType() int { + return (*orm.DateTimeField)(e).FieldType() +} + +// SetRaw convert the string or time.Time to DateTimeField +func (e *DateTimeField) SetRaw(value interface{}) error { + return (*orm.DateTimeField)(e).SetRaw(value) +} + +// RawValue return the datetime value +func (e *DateTimeField) RawValue() interface{} { + return (*orm.DateTimeField)(e).RawValue() +} + +// verify datetime implement fielder +var _ Fielder = new(DateTimeField) + +// FloatField A floating-point number represented in go by a float32 value. +type FloatField orm.FloatField + +// Value return the FloatField value +func (e FloatField) Value() float64 { + return orm.FloatField(e).Value() +} + +// Set the Float64 +func (e *FloatField) Set(d float64) { + (*orm.FloatField)(e).Set(d) +} + +// String return the string +func (e *FloatField) String() string { + return (*orm.FloatField)(e).String() +} + +// FieldType return the enum type +func (e *FloatField) FieldType() int { + return (*orm.FloatField)(e).FieldType() +} + +// SetRaw converter interface Float64 float32 or string to FloatField +func (e *FloatField) SetRaw(value interface{}) error { + return (*orm.FloatField)(e).SetRaw(value) +} + +// RawValue return the FloatField value +func (e *FloatField) RawValue() interface{} { + return (*orm.FloatField)(e).RawValue() +} + +// verify FloatField implement Fielder +var _ Fielder = new(FloatField) + +// SmallIntegerField -32768 to 32767 +type SmallIntegerField orm.SmallIntegerField + +// Value return int16 value +func (e SmallIntegerField) Value() int16 { + return orm.SmallIntegerField(e).Value() +} + +// Set the SmallIntegerField value +func (e *SmallIntegerField) Set(d int16) { + (*orm.SmallIntegerField)(e).Set(d) +} + +// String convert smallint to string +func (e *SmallIntegerField) String() string { + return (*orm.SmallIntegerField)(e).String() +} + +// FieldType return enum type SmallIntegerField +func (e *SmallIntegerField) FieldType() int { + return (*orm.SmallIntegerField)(e).FieldType() +} + +// SetRaw convert interface int16/string to int16 +func (e *SmallIntegerField) SetRaw(value interface{}) error { + return (*orm.SmallIntegerField)(e).SetRaw(value) +} + +// RawValue return smallint value +func (e *SmallIntegerField) RawValue() interface{} { + return (*orm.SmallIntegerField)(e).RawValue() +} + +// verify SmallIntegerField implement Fielder +var _ Fielder = new(SmallIntegerField) + +// IntegerField -2147483648 to 2147483647 +type IntegerField orm.IntegerField + +// Value return the int32 +func (e IntegerField) Value() int32 { + return orm.IntegerField(e).Value() +} + +// Set IntegerField value +func (e *IntegerField) Set(d int32) { + (*orm.IntegerField)(e).Set(d) +} + +// String convert Int32 to string +func (e *IntegerField) String() string { + return (*orm.IntegerField)(e).String() +} + +// FieldType return the enum type +func (e *IntegerField) FieldType() int { + return (*orm.IntegerField)(e).FieldType() +} + +// SetRaw convert interface int32/string to int32 +func (e *IntegerField) SetRaw(value interface{}) error { + return (*orm.IntegerField)(e).SetRaw(value) +} + +// RawValue return IntegerField value +func (e *IntegerField) RawValue() interface{} { + return (*orm.IntegerField)(e).RawValue() +} + +// verify IntegerField implement Fielder +var _ Fielder = new(IntegerField) + +// BigIntegerField -9223372036854775808 to 9223372036854775807. +type BigIntegerField orm.BigIntegerField + +// Value return int64 +func (e BigIntegerField) Value() int64 { + return orm.BigIntegerField(e).Value() +} + +// Set the BigIntegerField value +func (e *BigIntegerField) Set(d int64) { + (*orm.BigIntegerField)(e).Set(d) +} + +// String convert BigIntegerField to string +func (e *BigIntegerField) String() string { + return (*orm.BigIntegerField)(e).String() +} + +// FieldType return enum type +func (e *BigIntegerField) FieldType() int { + return (*orm.BigIntegerField)(e).FieldType() +} + +// SetRaw convert interface int64/string to int64 +func (e *BigIntegerField) SetRaw(value interface{}) error { + return (*orm.BigIntegerField)(e).SetRaw(value) +} + +// RawValue return BigIntegerField value +func (e *BigIntegerField) RawValue() interface{} { + return (*orm.BigIntegerField)(e).RawValue() +} + +// verify BigIntegerField implement Fielder +var _ Fielder = new(BigIntegerField) + +// PositiveSmallIntegerField 0 to 65535 +type PositiveSmallIntegerField orm.PositiveSmallIntegerField + +// Value return uint16 +func (e PositiveSmallIntegerField) Value() uint16 { + return orm.PositiveSmallIntegerField(e).Value() +} + +// Set PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) Set(d uint16) { + (*orm.PositiveSmallIntegerField)(e).Set(d) +} + +// String convert uint16 to string +func (e *PositiveSmallIntegerField) String() string { + return (*orm.PositiveSmallIntegerField)(e).String() +} + +// FieldType return enum type +func (e *PositiveSmallIntegerField) FieldType() int { + return (*orm.PositiveSmallIntegerField)(e).FieldType() +} + +// SetRaw convert Interface uint16/string to uint16 +func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { + return (*orm.PositiveSmallIntegerField)(e).SetRaw(value) +} + +// RawValue returns PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) RawValue() interface{} { + return (*orm.PositiveSmallIntegerField)(e).RawValue() +} + +// verify PositiveSmallIntegerField implement Fielder +var _ Fielder = new(PositiveSmallIntegerField) + +// PositiveIntegerField 0 to 4294967295 +type PositiveIntegerField orm.PositiveIntegerField + +// Value return PositiveIntegerField value. Uint32 +func (e PositiveIntegerField) Value() uint32 { + return orm.PositiveIntegerField(e).Value() +} + +// Set the PositiveIntegerField value +func (e *PositiveIntegerField) Set(d uint32) { + (*orm.PositiveIntegerField)(e).Set(d) +} + +// String convert PositiveIntegerField to string +func (e *PositiveIntegerField) String() string { + return (*orm.PositiveIntegerField)(e).String() +} + +// FieldType return enum type +func (e *PositiveIntegerField) FieldType() int { + return (*orm.PositiveIntegerField)(e).FieldType() +} + +// SetRaw convert interface uint32/string to Uint32 +func (e *PositiveIntegerField) SetRaw(value interface{}) error { + return (*orm.PositiveIntegerField)(e).SetRaw(value) +} + +// RawValue return the PositiveIntegerField Value +func (e *PositiveIntegerField) RawValue() interface{} { + return (*orm.PositiveIntegerField)(e).RawValue() +} + +// verify PositiveIntegerField implement Fielder +var _ Fielder = new(PositiveIntegerField) + +// PositiveBigIntegerField 0 to 18446744073709551615 +type PositiveBigIntegerField orm.PositiveBigIntegerField + +// Value return uint64 +func (e PositiveBigIntegerField) Value() uint64 { + return orm.PositiveBigIntegerField(e).Value() +} + +// Set PositiveBigIntegerField value +func (e *PositiveBigIntegerField) Set(d uint64) { + (*orm.PositiveBigIntegerField)(e).Set(d) +} + +// String convert PositiveBigIntegerField to string +func (e *PositiveBigIntegerField) String() string { + return (*orm.PositiveBigIntegerField)(e).String() +} + +// FieldType return enum type +func (e *PositiveBigIntegerField) FieldType() int { + return (*orm.PositiveBigIntegerField)(e).FieldType() +} + +// SetRaw convert interface uint64/string to Uint64 +func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { + return (*orm.PositiveBigIntegerField)(e).SetRaw(value) +} + +// RawValue return PositiveBigIntegerField value +func (e *PositiveBigIntegerField) RawValue() interface{} { + return (*orm.PositiveBigIntegerField)(e).RawValue() +} + +// verify PositiveBigIntegerField implement Fielder +var _ Fielder = new(PositiveBigIntegerField) + +// TextField A large text field. +type TextField orm.TextField + +// Value return TextField value +func (e TextField) Value() string { + return orm.TextField(e).Value() +} + +// Set the TextField value +func (e *TextField) Set(d string) { + (*orm.TextField)(e).Set(d) +} + +// String convert TextField to string +func (e *TextField) String() string { + return (*orm.TextField)(e).String() +} + +// FieldType return enum type +func (e *TextField) FieldType() int { + return (*orm.TextField)(e).FieldType() +} + +// SetRaw convert interface string to string +func (e *TextField) SetRaw(value interface{}) error { + return (*orm.TextField)(e).SetRaw(value) +} + +// RawValue return TextField value +func (e *TextField) RawValue() interface{} { + return (*orm.TextField)(e).RawValue() +} + +// verify TextField implement Fielder +var _ Fielder = new(TextField) + +// JSONField postgres json field. +type JSONField orm.JSONField + +// Value return JSONField value +func (j JSONField) Value() string { + return orm.JSONField(j).Value() +} + +// Set the JSONField value +func (j *JSONField) Set(d string) { + (*orm.JSONField)(j).Set(d) +} + +// String convert JSONField to string +func (j *JSONField) String() string { + return (*orm.JSONField)(j).String() +} + +// FieldType return enum type +func (j *JSONField) FieldType() int { + return (*orm.JSONField)(j).FieldType() +} + +// SetRaw convert interface string to string +func (j *JSONField) SetRaw(value interface{}) error { + return (*orm.JSONField)(j).SetRaw(value) +} + +// RawValue return JSONField value +func (j *JSONField) RawValue() interface{} { + return (*orm.JSONField)(j).RawValue() +} + +// verify JSONField implement Fielder +var _ Fielder = new(JSONField) + +// JsonbField postgres json field. +type JsonbField orm.JsonbField + +// Value return JsonbField value +func (j JsonbField) Value() string { + return orm.JsonbField(j).Value() +} + +// Set the JsonbField value +func (j *JsonbField) Set(d string) { + (*orm.JsonbField)(j).Set(d) +} + +// String convert JsonbField to string +func (j *JsonbField) String() string { + return (*orm.JsonbField)(j).String() +} + +// FieldType return enum type +func (j *JsonbField) FieldType() int { + return (*orm.JsonbField)(j).FieldType() +} + +// SetRaw convert interface string to string +func (j *JsonbField) SetRaw(value interface{}) error { + return (*orm.JsonbField)(j).SetRaw(value) +} + +// RawValue return JsonbField value +func (j *JsonbField) RawValue() interface{} { + return (*orm.JsonbField)(j).RawValue() +} + +// verify JsonbField implement Fielder +var _ Fielder = new(JsonbField) diff --git a/pkg/adapter/orm/orm.go b/pkg/adapter/orm/orm.go new file mode 100644 index 00000000..f8463ea2 --- /dev/null +++ b/pkg/adapter/orm/orm.go @@ -0,0 +1,314 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +// Package orm provide ORM for MySQL/PostgreSQL/sqlite +// Simple Usage +// +// package main +// +// import ( +// "fmt" +// "github.com/astaxie/beego/orm" +// _ "github.com/go-sql-driver/mysql" // import your used driver +// ) +// +// // Model Struct +// type User struct { +// Id int `orm:"auto"` +// Name string `orm:"size(100)"` +// } +// +// func init() { +// orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) +// } +// +// func main() { +// o := orm.NewOrm() +// user := User{Name: "slene"} +// // insert +// id, err := o.Insert(&user) +// // update +// user.Name = "astaxie" +// num, err := o.Update(&user) +// // read one +// u := User{Id: user.Id} +// err = o.Read(&u) +// // delete +// num, err = o.Delete(&u) +// } +// +// more docs: http://beego.me/docs/mvc/model/overview.md +package orm + +import ( + "context" + "database/sql" + "errors" + + "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/pkg/client/orm/hints" + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// DebugQueries define the debug +const ( + DebugQueries = iota +) + +// Define common vars +var ( + Debug = orm.Debug + DebugLog = orm.DebugLog + DefaultRowsLimit = orm.DefaultRowsLimit + DefaultRelsDepth = orm.DefaultRelsDepth + DefaultTimeLoc = orm.DefaultTimeLoc + ErrTxHasBegan = errors.New(" transaction already begin") + ErrTxDone = errors.New(" transaction not begin") + ErrMultiRows = errors.New(" return multi rows") + ErrNoRows = errors.New(" no row found") + ErrStmtClosed = errors.New(" stmt already closed") + ErrArgs = errors.New(" args error may be empty") + ErrNotImplement = errors.New("have not implement") +) + +type ormer struct { + delegate orm.Ormer + txDelegate orm.TxOrmer + isTx bool +} + +var _ Ormer = new(ormer) + +// read data to model +func (o *ormer) Read(md interface{}, cols ...string) error { + if o.isTx { + return o.txDelegate.Read(md, cols...) + } + return o.delegate.Read(md, cols...) +} + +// read data to model, like Read(), but use "SELECT FOR UPDATE" form +func (o *ormer) ReadForUpdate(md interface{}, cols ...string) error { + if o.isTx { + return o.txDelegate.ReadForUpdate(md, cols...) + } + return o.delegate.ReadForUpdate(md, cols...) +} + +// Try to read a row from the database, or insert one if it doesn't exist +func (o *ormer) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + if o.isTx { + return o.txDelegate.ReadOrCreate(md, col1, cols...) + } + return o.delegate.ReadOrCreate(md, col1, cols...) +} + +// insert model data to database +func (o *ormer) Insert(md interface{}) (int64, error) { + if o.isTx { + return o.txDelegate.Insert(md) + } + return o.delegate.Insert(md) +} + +// insert some models to database +func (o *ormer) InsertMulti(bulk int, mds interface{}) (int64, error) { + if o.isTx { + return o.txDelegate.InsertMulti(bulk, mds) + } + return o.delegate.InsertMulti(bulk, mds) +} + +// InsertOrUpdate data to database +func (o *ormer) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { + if o.isTx { + return o.txDelegate.InsertOrUpdate(md, colConflitAndArgs...) + } + return o.delegate.InsertOrUpdate(md, colConflitAndArgs...) +} + +// update model to database. +// cols set the columns those want to update. +func (o *ormer) Update(md interface{}, cols ...string) (int64, error) { + if o.isTx { + return o.txDelegate.Update(md, cols...) + } + return o.delegate.Update(md, cols...) +} + +// delete model in database +// cols shows the delete conditions values read from. default is pk +func (o *ormer) Delete(md interface{}, cols ...string) (int64, error) { + if o.isTx { + return o.txDelegate.Delete(md, cols...) + } + return o.delegate.Delete(md, cols...) +} + +// create a models to models queryer +func (o *ormer) QueryM2M(md interface{}, name string) QueryM2Mer { + if o.isTx { + return o.txDelegate.QueryM2M(md, name) + } + return o.delegate.QueryM2M(md, name) +} + +// load related models to md model. +// args are limit, offset int and order string. +// +// example: +// orm.LoadRelated(post,"Tags") +// for _,tag := range post.Tags{...} +// +// make sure the relation is defined in model struct tags. +func (o *ormer) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { + kvs := make([]utils.KV, 0, 4) + for i, arg := range args { + switch i { + case 0: + if v, ok := arg.(bool); ok { + if v { + kvs = append(kvs, hints.DefaultRelDepth()) + } + } else if v, ok := arg.(int); ok { + kvs = append(kvs, hints.RelDepth(v)) + } + case 1: + kvs = append(kvs, hints.Limit(orm.ToInt64(arg))) + case 2: + kvs = append(kvs, hints.Offset(orm.ToInt64(arg))) + case 3: + kvs = append(kvs, hints.Offset(orm.ToInt64(arg))) + } + } + if o.isTx { + return o.txDelegate.LoadRelated(md, name, kvs...) + } + return o.delegate.LoadRelated(md, name, kvs...) +} + +// return a QuerySeter for table operations. +// table name can be string or struct. +// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), +func (o *ormer) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { + if o.isTx { + return o.txDelegate.QueryTable(ptrStructOrTableName) + } + return o.delegate.QueryTable(ptrStructOrTableName) +} + +// switch to another registered database driver by given name. +func (o *ormer) Using(name string) error { + if o.isTx { + return ErrTxHasBegan + } + o.delegate = orm.NewOrmUsingDB(name) + return nil +} + +// begin transaction +func (o *ormer) Begin() error { + if o.isTx { + return ErrTxHasBegan + } + return o.BeginTx(context.Background(), nil) +} + +func (o *ormer) BeginTx(ctx context.Context, opts *sql.TxOptions) error { + if o.isTx { + return ErrTxHasBegan + } + txOrmer, err := o.delegate.BeginWithCtxAndOpts(ctx, opts) + if err != nil { + return err + } + o.txDelegate = txOrmer + o.isTx = true + return nil +} + +// commit transaction +func (o *ormer) Commit() error { + if !o.isTx { + return ErrTxDone + } + err := o.txDelegate.Commit() + if err == nil { + o.isTx = false + o.txDelegate = nil + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// rollback transaction +func (o *ormer) Rollback() error { + if !o.isTx { + return ErrTxDone + } + err := o.txDelegate.Rollback() + if err == nil { + o.isTx = false + o.txDelegate = nil + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// return a raw query seter for raw sql string. +func (o *ormer) Raw(query string, args ...interface{}) RawSeter { + if o.isTx { + return o.txDelegate.Raw(query, args...) + } + return o.delegate.Raw(query, args...) +} + +// return current using database Driver +func (o *ormer) Driver() Driver { + if o.isTx { + return o.txDelegate.Driver() + } + return o.delegate.Driver() +} + +// return sql.DBStats for current database +func (o *ormer) DBStats() *sql.DBStats { + if o.isTx { + return o.txDelegate.DBStats() + } + return o.delegate.DBStats() +} + +// NewOrm create new orm +func NewOrm() Ormer { + o := orm.NewOrm() + return &ormer{ + delegate: o, + } +} + +// NewOrmWithDB create a new ormer object with specify *sql.DB for query +func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { + o, err := orm.NewOrmWithDB(driverName, aliasName, db) + if err != nil { + return nil, err + } + return &ormer{ + delegate: o, + }, nil +} diff --git a/pkg/adapter/orm/orm_conds.go b/pkg/adapter/orm/orm_conds.go new file mode 100644 index 00000000..986b4858 --- /dev/null +++ b/pkg/adapter/orm/orm_conds.go @@ -0,0 +1,83 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// ExprSep define the expression separation +const ( + ExprSep = "__" +) + +// Condition struct. +// work for WHERE conditions. +type Condition orm.Condition + +// NewCondition return new condition struct +func NewCondition() *Condition { + return (*Condition)(orm.NewCondition()) +} + +// Raw add raw sql to condition +func (c Condition) Raw(expr string, sql string) *Condition { + return (*Condition)((orm.Condition)(c).Raw(expr, sql)) +} + +// And add expression to condition +func (c Condition) And(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).And(expr, args...)) +} + +// AndNot add NOT expression to condition +func (c Condition) AndNot(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).AndNot(expr, args...)) +} + +// AndCond combine a condition to current condition +func (c *Condition) AndCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).AndCond((*orm.Condition)(cond))) +} + +// AndNotCond combine a AND NOT condition to current condition +func (c *Condition) AndNotCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).AndNotCond((*orm.Condition)(cond))) +} + +// Or add OR expression to condition +func (c Condition) Or(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).Or(expr, args...)) +} + +// OrNot add OR NOT expression to condition +func (c Condition) OrNot(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).OrNot(expr, args...)) +} + +// OrCond combine a OR condition to current condition +func (c *Condition) OrCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).OrCond((*orm.Condition)(cond))) +} + +// OrNotCond combine a OR NOT condition to current condition +func (c *Condition) OrNotCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).OrNotCond((*orm.Condition)(cond))) +} + +// IsEmpty check the condition arguments are empty or not. +func (c *Condition) IsEmpty() bool { + return (*orm.Condition)(c).IsEmpty() +} diff --git a/pkg/adapter/orm/orm_log.go b/pkg/adapter/orm/orm_log.go new file mode 100644 index 00000000..6b2b4a9b --- /dev/null +++ b/pkg/adapter/orm/orm_log.go @@ -0,0 +1,32 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "io" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// Log implement the log.Logger +type Log orm.Log + +// costomer log func +var LogFunc = orm.LogFunc + +// NewLog set io.Writer to create a Logger. +func NewLog(out io.Writer) *Log { + return (*Log)(orm.NewLog(out)) +} diff --git a/pkg/adapter/orm/orm_queryset.go b/pkg/adapter/orm/orm_queryset.go new file mode 100644 index 00000000..5f211644 --- /dev/null +++ b/pkg/adapter/orm/orm_queryset.go @@ -0,0 +1,32 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// define Col operations +const ( + ColAdd = orm.ColAdd + ColMinus = orm.ColMinus + ColMultiply = orm.ColMultiply + ColExcept = orm.ColExcept + ColBitAnd = orm.ColBitAnd + ColBitRShift = orm.ColBitRShift + ColBitLShift = orm.ColBitLShift + ColBitXOR = orm.ColBitXOR + ColBitOr = orm.ColBitOr +) diff --git a/pkg/adapter/orm/qb.go b/pkg/adapter/orm/qb.go new file mode 100644 index 00000000..90b97797 --- /dev/null +++ b/pkg/adapter/orm/qb.go @@ -0,0 +1,27 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// QueryBuilder is the Query builder interface +type QueryBuilder orm.QueryBuilder + +// NewQueryBuilder return the QueryBuilder +func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { + return orm.NewQueryBuilder(driver) +} diff --git a/pkg/adapter/orm/qb_mysql.go b/pkg/adapter/orm/qb_mysql.go new file mode 100644 index 00000000..9566068f --- /dev/null +++ b/pkg/adapter/orm/qb_mysql.go @@ -0,0 +1,150 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// CommaSpace is the separation +const CommaSpace = orm.CommaSpace + +// MySQLQueryBuilder is the SQL build +type MySQLQueryBuilder orm.MySQLQueryBuilder + +// Select will join the fields +func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Select(fields...) +} + +// ForUpdate add the FOR UPDATE clause +func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).ForUpdate() +} + +// From join the tables +func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).From(tables...) +} + +// InnerJoin INNER JOIN the table +func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).InnerJoin(table) +} + +// LeftJoin LEFT JOIN the table +func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).LeftJoin(table) +} + +// RightJoin RIGHT JOIN the table +func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).RightJoin(table) +} + +// On join with on cond +func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).On(cond) +} + +// Where join the Where cond +func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Where(cond) +} + +// And join the and cond +func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).And(cond) +} + +// Or join the or cond +func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Or(cond) +} + +// In join the IN (vals) +func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).In(vals...) +} + +// OrderBy join the Order by fields +func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).OrderBy(fields...) +} + +// Asc join the asc +func (qb *MySQLQueryBuilder) Asc() QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Asc() +} + +// Desc join the desc +func (qb *MySQLQueryBuilder) Desc() QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Desc() +} + +// Limit join the limit num +func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Limit(limit) +} + +// Offset join the offset num +func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Offset(offset) +} + +// GroupBy join the Group by fields +func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).GroupBy(fields...) +} + +// Having join the Having cond +func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Having(cond) +} + +// Update join the update table +func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Update(tables...) +} + +// Set join the set kv +func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Set(kv...) +} + +// Delete join the Delete tables +func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Delete(tables...) +} + +// InsertInto join the insert SQL +func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).InsertInto(table, fields...) +} + +// Values join the Values(vals) +func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Values(vals...) +} + +// Subquery join the sub as alias +func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { + return (*orm.MySQLQueryBuilder)(qb).Subquery(sub, alias) +} + +// String join all Tokens +func (qb *MySQLQueryBuilder) String() string { + return (*orm.MySQLQueryBuilder)(qb).String() +} diff --git a/pkg/adapter/orm/qb_tidb.go b/pkg/adapter/orm/qb_tidb.go new file mode 100644 index 00000000..05c91a26 --- /dev/null +++ b/pkg/adapter/orm/qb_tidb.go @@ -0,0 +1,147 @@ +// Copyright 2015 TiDB Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// TiDBQueryBuilder is the SQL build +type TiDBQueryBuilder orm.TiDBQueryBuilder + +// Select will join the fields +func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Select(fields...) +} + +// ForUpdate add the FOR UPDATE clause +func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).ForUpdate() +} + +// From join the tables +func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).From(tables...) +} + +// InnerJoin INNER JOIN the table +func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).InnerJoin(table) +} + +// LeftJoin LEFT JOIN the table +func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).LeftJoin(table) +} + +// RightJoin RIGHT JOIN the table +func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).RightJoin(table) +} + +// On join with on cond +func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).On(cond) +} + +// Where join the Where cond +func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Where(cond) +} + +// And join the and cond +func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).And(cond) +} + +// Or join the or cond +func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Or(cond) +} + +// In join the IN (vals) +func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).In(vals...) +} + +// OrderBy join the Order by fields +func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).OrderBy(fields...) +} + +// Asc join the asc +func (qb *TiDBQueryBuilder) Asc() QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Asc() +} + +// Desc join the desc +func (qb *TiDBQueryBuilder) Desc() QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Desc() +} + +// Limit join the limit num +func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Limit(limit) +} + +// Offset join the offset num +func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Offset(offset) +} + +// GroupBy join the Group by fields +func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).GroupBy(fields...) +} + +// Having join the Having cond +func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Having(cond) +} + +// Update join the update table +func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Update(tables...) +} + +// Set join the set kv +func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Set(kv...) +} + +// Delete join the Delete tables +func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Delete(tables...) +} + +// InsertInto join the insert SQL +func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).InsertInto(table, fields...) +} + +// Values join the Values(vals) +func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Values(vals...) +} + +// Subquery join the sub as alias +func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string { + return (*orm.TiDBQueryBuilder)(qb).Subquery(sub, alias) +} + +// String join all Tokens +func (qb *TiDBQueryBuilder) String() string { + return (*orm.TiDBQueryBuilder)(qb).String() +} diff --git a/pkg/adapter/orm/query_setter_adapter.go b/pkg/adapter/orm/query_setter_adapter.go new file mode 100644 index 00000000..cc24ef6b --- /dev/null +++ b/pkg/adapter/orm/query_setter_adapter.go @@ -0,0 +1,34 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +type baseQuerySetter struct { +} + +func (b *baseQuerySetter) ForceIndex(indexes ...string) orm.QuerySeter { + panic("you should not invoke this method.") +} + +func (b *baseQuerySetter) UseIndex(indexes ...string) orm.QuerySeter { + panic("you should not invoke this method.") +} + +func (b *baseQuerySetter) IgnoreIndex(indexes ...string) orm.QuerySeter { + panic("you should not invoke this method.") +} diff --git a/pkg/adapter/orm/types.go b/pkg/adapter/orm/types.go new file mode 100644 index 00000000..3372e301 --- /dev/null +++ b/pkg/adapter/orm/types.go @@ -0,0 +1,150 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// Params stores the Params +type Params orm.Params + +// ParamsList stores paramslist +type ParamsList orm.ParamsList + +// Driver define database driver +type Driver orm.Driver + +// Fielder define field info +type Fielder orm.Fielder + +// Ormer define the orm interface +type Ormer interface { + // read data to model + // for example: + // this will find User by Id field + // u = &User{Id: user.Id} + // err = Ormer.Read(u) + // this will find User by UserName field + // u = &User{UserName: "astaxie", Password: "pass"} + // err = Ormer.Read(u, "UserName") + Read(md interface{}, cols ...string) error + // Like Read(), but with "FOR UPDATE" clause, useful in transaction. + // Some databases are not support this feature. + ReadForUpdate(md interface{}, cols ...string) error + // Try to read a row from the database, or insert one if it doesn't exist + ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) + // insert model data to database + // for example: + // user := new(User) + // id, err = Ormer.Insert(user) + // user must be a pointer and Insert will set user's pk field + Insert(interface{}) (int64, error) + // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") + // if colu type is integer : can use(+-*/), string : convert(colu,"value") + // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") + // if colu type is integer : can use(+-*/), string : colu || "value" + InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) + // insert some models to database + InsertMulti(bulk int, mds interface{}) (int64, error) + // update model to database. + // cols set the columns those want to update. + // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns + // for example: + // user := User{Id: 2} + // user.Langs = append(user.Langs, "zh-CN", "en-US") + // user.Extra.Name = "beego" + // user.Extra.Data = "orm" + // num, err = Ormer.Update(&user, "Langs", "Extra") + Update(md interface{}, cols ...string) (int64, error) + // delete model in database + Delete(md interface{}, cols ...string) (int64, error) + // load related models to md model. + // args are limit, offset int and order string. + // + // example: + // Ormer.LoadRelated(post,"Tags") + // for _,tag := range post.Tags{...} + // args[0] bool true useDefaultRelsDepth ; false depth 0 + // args[0] int loadRelationDepth + // args[1] int limit default limit 1000 + // args[2] int offset default offset 0 + // args[3] string order for example : "-Id" + // make sure the relation is defined in model struct tags. + LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) + // create a models to models queryer + // for example: + // post := Post{Id: 4} + // m2m := Ormer.QueryM2M(&post, "Tags") + QueryM2M(md interface{}, name string) QueryM2Mer + // return a QuerySeter for table operations. + // table name can be string or struct. + // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), + QueryTable(ptrStructOrTableName interface{}) QuerySeter + // switch to another registered database driver by given name. + Using(name string) error + // begin transaction + // for example: + // o := NewOrm() + // err := o.Begin() + // ... + // err = o.Rollback() + Begin() error + // begin transaction with provided context and option + // the provided context is used until the transaction is committed or rolled back. + // if the context is canceled, the transaction will be rolled back. + // the provided TxOptions is optional and may be nil if defaults should be used. + // if a non-default isolation level is used that the driver doesn't support, an error will be returned. + // for example: + // o := NewOrm() + // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + // ... + // err = o.Rollback() + BeginTx(ctx context.Context, opts *sql.TxOptions) error + // commit transaction + Commit() error + // rollback transaction + Rollback() error + // return a raw query seter for raw sql string. + // for example: + // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() + // // update user testing's name to slene + Raw(query string, args ...interface{}) RawSeter + Driver() Driver + DBStats() *sql.DBStats +} + +// Inserter insert prepared statement +type Inserter orm.Inserter + +// QuerySeter query seter +type QuerySeter orm.QuerySeter + +// QueryM2Mer model to model query struct +// all operations are on the m2m table only, will not affect the origin model table +type QueryM2Mer orm.QueryM2Mer + +// RawPreparer raw query statement +type RawPreparer orm.RawPreparer + +// RawSeter raw query seter +// create From Ormer.Raw +// for example: +// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) +// rs := Ormer.Raw(sql, 1) +type RawSeter orm.RawSeter diff --git a/pkg/adapter/orm/utils.go b/pkg/adapter/orm/utils.go new file mode 100644 index 00000000..16d0e4e5 --- /dev/null +++ b/pkg/adapter/orm/utils.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "time" + + "github.com/astaxie/beego/pkg/client/orm" +) + +type fn func(string) string + +var ( + nameStrategyMap = map[string]fn{ + defaultNameStrategy: snakeString, + SnakeAcronymNameStrategy: snakeStringWithAcronym, + } + defaultNameStrategy = "snakeString" + SnakeAcronymNameStrategy = "snakeStringWithAcronym" + nameStrategy = defaultNameStrategy +) + +// StrTo is the target string +type StrTo orm.StrTo + +// Set string +func (f *StrTo) Set(v string) { + (*orm.StrTo)(f).Set(v) +} + +// Clear string +func (f *StrTo) Clear() { + (*orm.StrTo)(f).Clear() +} + +// Exist check string exist +func (f StrTo) Exist() bool { + return orm.StrTo(f).Exist() +} + +// Bool string to bool +func (f StrTo) Bool() (bool, error) { + return orm.StrTo(f).Bool() +} + +// Float32 string to float32 +func (f StrTo) Float32() (float32, error) { + return orm.StrTo(f).Float32() +} + +// Float64 string to float64 +func (f StrTo) Float64() (float64, error) { + return orm.StrTo(f).Float64() +} + +// Int string to int +func (f StrTo) Int() (int, error) { + return orm.StrTo(f).Int() +} + +// Int8 string to int8 +func (f StrTo) Int8() (int8, error) { + return orm.StrTo(f).Int8() +} + +// Int16 string to int16 +func (f StrTo) Int16() (int16, error) { + return orm.StrTo(f).Int16() +} + +// Int32 string to int32 +func (f StrTo) Int32() (int32, error) { + return orm.StrTo(f).Int32() +} + +// Int64 string to int64 +func (f StrTo) Int64() (int64, error) { + return orm.StrTo(f).Int64() +} + +// Uint string to uint +func (f StrTo) Uint() (uint, error) { + return orm.StrTo(f).Uint() +} + +// Uint8 string to uint8 +func (f StrTo) Uint8() (uint8, error) { + return orm.StrTo(f).Uint8() +} + +// Uint16 string to uint16 +func (f StrTo) Uint16() (uint16, error) { + return orm.StrTo(f).Uint16() +} + +// Uint32 string to uint32 +func (f StrTo) Uint32() (uint32, error) { + return orm.StrTo(f).Uint32() +} + +// Uint64 string to uint64 +func (f StrTo) Uint64() (uint64, error) { + return orm.StrTo(f).Uint64() +} + +// String string to string +func (f StrTo) String() string { + return orm.StrTo(f).String() +} + +// ToStr interface to string +func ToStr(value interface{}, args ...int) (s string) { + switch v := value.(type) { + case bool: + s = strconv.FormatBool(v) + case float32: + s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) + case float64: + s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) + case int: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int8: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int16: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int32: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int64: + s = strconv.FormatInt(v, argInt(args).Get(0, 10)) + case uint: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint8: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint16: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint32: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint64: + s = strconv.FormatUint(v, argInt(args).Get(0, 10)) + case string: + s = v + case []byte: + s = string(v) + default: + s = fmt.Sprintf("%v", v) + } + return s +} + +// ToInt64 interface to int64 +func ToInt64(value interface{}) (d int64) { + val := reflect.ValueOf(value) + switch value.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + panic(fmt.Errorf("ToInt64 need numeric not `%T`", value)) + } + return +} + +func snakeStringWithAcronym(s string) string { + data := make([]byte, 0, len(s)*2) + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + before := false + after := false + if i > 0 { + before = s[i-1] >= 'a' && s[i-1] <= 'z' + } + if i+1 < num { + after = s[i+1] >= 'a' && s[i+1] <= 'z' + } + if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { + data = append(data, '_') + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// snake string, XxYy to xx_yy , XxYY to xx_y_y +func snakeString(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + if i > 0 && d >= 'A' && d <= 'Z' && j { + data = append(data, '_') + } + if d != '_' { + j = true + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// SetNameStrategy set different name strategy +func SetNameStrategy(s string) { + if SnakeAcronymNameStrategy != s { + nameStrategy = defaultNameStrategy + } + nameStrategy = s +} + +// camel string, xx_yy to XxYy +func camelString(s string) string { + data := make([]byte, 0, len(s)) + flag, num := true, len(s)-1 + for i := 0; i <= num; i++ { + d := s[i] + if d == '_' { + flag = true + continue + } else if flag { + if d >= 'a' && d <= 'z' { + d = d - 32 + } + flag = false + } + data = append(data, d) + } + return string(data[:]) +} + +type argString []string + +// get string by index from string slice +func (a argString) Get(i int, args ...string) (r string) { + if i >= 0 && i < len(a) { + r = a[i] + } else if len(args) > 0 { + r = args[0] + } + return +} + +type argInt []int + +// get int by index from int slice +func (a argInt) Get(i int, args ...int) (r int) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +// parse time to string with location +func timeParse(dateString, format string) (time.Time, error) { + tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) + return tp, err +} + +// get pointer indirect type +func indirectType(v reflect.Type) reflect.Type { + switch v.Kind() { + case reflect.Ptr: + return indirectType(v.Elem()) + default: + return v + } +} diff --git a/pkg/adapter/orm/utils_test.go b/pkg/adapter/orm/utils_test.go new file mode 100644 index 00000000..7d94cada --- /dev/null +++ b/pkg/adapter/orm/utils_test.go @@ -0,0 +1,70 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" +) + +func TestCamelString(t *testing.T) { + snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} + + answer := make(map[string]string) + for i, v := range snake { + answer[v] = camel[i] + } + + for _, v := range snake { + res := camelString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeString(t *testing.T) { + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeStringWithAcronym(t *testing.T) { + camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeStringWithAcronym(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} diff --git a/pkg/adapter/plugins/apiauth/apiauth.go b/pkg/adapter/plugins/apiauth/apiauth.go new file mode 100644 index 00000000..ed43f8a0 --- /dev/null +++ b/pkg/adapter/plugins/apiauth/apiauth.go @@ -0,0 +1,94 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package apiauth provides handlers to enable apiauth support. +// +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/apiauth" +// ) +// +// func main(){ +// // apiauth every request +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey")) +// beego.Run() +// } +// +// Advanced Usage: +// +// func getAppSecret(appid string) string { +// // get appsecret by appid +// // maybe store in configure, maybe in database +// } +// +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360)) +// +// Information: +// +// In the request user should include these params in the query +// +// 1. appid +// +// appid is assigned to the application +// +// 2. signature +// +// get the signature use apiauth.Signature() +// +// when you send to server remember use url.QueryEscape() +// +// 3. timestamp: +// +// send the request time, the format is yyyy-mm-dd HH:ii:ss +// +package apiauth + +import ( + "net/url" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/apiauth" +) + +// AppIDToAppSecret is used to get appsecret throw appid +type AppIDToAppSecret apiauth.AppIDToAppSecret + +// APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret +func APIBasicAuth(appid, appkey string) beego.FilterFunc { + f := apiauth.APIBasicAuth(appid, appkey) + return func(c *context.Context) { + f((*beecontext.Context)(c)) + } +} + +// APIBaiscAuth calls APIBasicAuth for previous callers +func APIBaiscAuth(appid, appkey string) beego.FilterFunc { + return APIBasicAuth(appid, appkey) +} + +// APISecretAuth use AppIdToAppSecret verify and +func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { + ft := apiauth.APISecretAuth(apiauth.AppIDToAppSecret(f), timeout) + return func(ctx *context.Context) { + ft((*beecontext.Context)(ctx)) + } +} + +// Signature used to generate signature with the appsecret/method/params/RequestURI +func Signature(appsecret, method string, params url.Values, requestURL string) string { + return apiauth.Signature(appsecret, method, params, requestURL) +} diff --git a/pkg/adapter/plugins/apiauth/apiauth_test.go b/pkg/adapter/plugins/apiauth/apiauth_test.go new file mode 100644 index 00000000..1f56cb0f --- /dev/null +++ b/pkg/adapter/plugins/apiauth/apiauth_test.go @@ -0,0 +1,20 @@ +package apiauth + +import ( + "net/url" + "testing" +) + +func TestSignature(t *testing.T) { + appsecret := "beego secret" + method := "GET" + RequestURL := "http://localhost/test/url" + params := make(url.Values) + params.Add("arg1", "hello") + params.Add("arg2", "beego") + + signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58=" + if Signature(appsecret, method, params, RequestURL) != signature { + t.Error("Signature error") + } +} diff --git a/pkg/adapter/plugins/auth/basic.go b/pkg/adapter/plugins/auth/basic.go new file mode 100644 index 00000000..7a9cd326 --- /dev/null +++ b/pkg/adapter/plugins/auth/basic.go @@ -0,0 +1,81 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package auth provides handlers to enable basic auth support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/auth" +// ) +// +// func main(){ +// // authenticate every request +// beego.InsertFilter("*", beego.BeforeRouter,auth.Basic("username","secretpassword")) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func SecretAuth(username, password string) bool { +// return username == "astaxie" && password == "helloBeego" +// } +// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "Authorization Required") +// beego.InsertFilter("*", beego.BeforeRouter,authPlugin) +package auth + +import ( + "net/http" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/auth" +) + +// Basic is the http basic auth +func Basic(username string, password string) beego.FilterFunc { + return func(c *context.Context) { + f := auth.Basic(username, password) + f((*beecontext.Context)(c)) + } +} + +// NewBasicAuthenticator return the BasicAuth +func NewBasicAuthenticator(secrets SecretProvider, realm string) beego.FilterFunc { + f := auth.NewBasicAuthenticator(auth.SecretProvider(secrets), realm) + return func(c *context.Context) { + f((*beecontext.Context)(c)) + } +} + +// SecretProvider is the SecretProvider function +type SecretProvider auth.SecretProvider + +// BasicAuth store the SecretProvider and Realm +type BasicAuth auth.BasicAuth + +// CheckAuth Checks the username/password combination from the request. Returns +// either an empty string (authentication failed) or the name of the +// authenticated user. +// Supports MD5 and SHA1 password entries +func (a *BasicAuth) CheckAuth(r *http.Request) string { + return (*auth.BasicAuth)(a).CheckAuth(r) +} + +// RequireAuth http.Handler for BasicAuth which initiates the authentication process +// (or requires reauthentication). +func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) { + (*auth.BasicAuth)(a).RequireAuth(w, r) +} diff --git a/pkg/adapter/plugins/authz/authz.go b/pkg/adapter/plugins/authz/authz.go new file mode 100644 index 00000000..c38be9cb --- /dev/null +++ b/pkg/adapter/plugins/authz/authz.go @@ -0,0 +1,80 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/authz" +// "github.com/casbin/casbin" +// ) +// +// func main(){ +// // mediate the access for every request +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func main(){ +// e := casbin.NewEnforcer("authz_model.conf", "") +// e.AddRoleForUser("alice", "admin") +// e.AddPolicy(...) +// +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e)) +// beego.Run() +// } +package authz + +import ( + "net/http" + + "github.com/casbin/casbin" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/authz" +) + +// NewAuthorizer returns the authorizer. +// Use a casbin enforcer as input +func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc { + f := authz.NewAuthorizer(e) + return func(context *context.Context) { + f((*beecontext.Context)(context)) + } +} + +// BasicAuthorizer stores the casbin handler +type BasicAuthorizer authz.BasicAuthorizer + +// GetUserName gets the user name from the request. +// Currently, only HTTP basic authentication is supported +func (a *BasicAuthorizer) GetUserName(r *http.Request) string { + return (*authz.BasicAuthorizer)(a).GetUserName(r) +} + +// CheckPermission checks the user/method/path combination from the request. +// Returns true (permission granted) or false (permission forbidden) +func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool { + return (*authz.BasicAuthorizer)(a).CheckPermission(r) +} + +// RequirePermission returns the 403 Forbidden to the client +func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) { + (*authz.BasicAuthorizer)(a).RequirePermission(w) +} diff --git a/pkg/adapter/plugins/authz/authz_model.conf b/pkg/adapter/plugins/authz/authz_model.conf new file mode 100644 index 00000000..d1b3dbd7 --- /dev/null +++ b/pkg/adapter/plugins/authz/authz_model.conf @@ -0,0 +1,14 @@ +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") \ No newline at end of file diff --git a/pkg/adapter/plugins/authz/authz_policy.csv b/pkg/adapter/plugins/authz/authz_policy.csv new file mode 100644 index 00000000..c062dd3e --- /dev/null +++ b/pkg/adapter/plugins/authz/authz_policy.csv @@ -0,0 +1,7 @@ +p, alice, /dataset1/*, GET +p, alice, /dataset1/resource1, POST +p, bob, /dataset2/resource1, * +p, bob, /dataset2/resource2, GET +p, bob, /dataset2/folder1/*, POST +p, dataset1_admin, /dataset1/*, * +g, cathy, dataset1_admin \ No newline at end of file diff --git a/pkg/adapter/plugins/authz/authz_test.go b/pkg/adapter/plugins/authz/authz_test.go new file mode 100644 index 00000000..ddbda5f4 --- /dev/null +++ b/pkg/adapter/plugins/authz/authz_test.go @@ -0,0 +1,108 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authz + +import ( + "net/http" + "net/http/httptest" + "testing" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/adapter/plugins/auth" + "github.com/casbin/casbin" +) + +func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + r.SetBasicAuth(user, "123") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code) + } +} + +func TestBasic(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403) +} + +func TestPathWildcard(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403) + testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403) + + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403) +} + +func TestRBAC(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) + e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role. + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) + + // delete all roles on user cathy, so cathy cannot access any resources now. + e.DeleteRolesForUser("cathy") + + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) +} diff --git a/pkg/adapter/plugins/cors/cors.go b/pkg/adapter/plugins/cors/cors.go new file mode 100644 index 00000000..65af8b8f --- /dev/null +++ b/pkg/adapter/plugins/cors/cors.go @@ -0,0 +1,71 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cors provides handlers to enable CORS support. +// Usage +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/cors" +// ) +// +// func main() { +// // CORS for https://foo.* origins, allowing: +// // - PUT and PATCH methods +// // - Origin header +// // - Credentials share +// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{ +// AllowOrigins: []string{"https://*.foo.com"}, +// AllowMethods: []string{"PUT", "PATCH"}, +// AllowHeaders: []string{"Origin"}, +// ExposeHeaders: []string{"Content-Length"}, +// AllowCredentials: true, +// })) +// beego.Run() +// } +package cors + +import ( + beego "github.com/astaxie/beego/pkg/adapter" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/cors" + + "github.com/astaxie/beego/pkg/adapter/context" +) + +// Options represents Access Control options. +type Options cors.Options + +// Header converts options into CORS headers. +func (o *Options) Header(origin string) (headers map[string]string) { + return (*cors.Options)(o).Header(origin) +} + +// PreflightHeader converts options into CORS headers for a preflight response. +func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) { + return (*cors.Options)(o).PreflightHeader(origin, rMethod, rHeaders) +} + +// IsOriginAllowed looks up if the origin matches one of the patterns +// generated from Options.AllowOrigins patterns. +func (o *Options) IsOriginAllowed(origin string) bool { + return (*cors.Options)(o).IsOriginAllowed(origin) +} + +// Allow enables CORS for requests those match the provided options. +func Allow(opts *Options) beego.FilterFunc { + f := cors.Allow((*cors.Options)(opts)) + return func(c *context.Context) { + f((*beecontext.Context)(c)) + } +} diff --git a/pkg/adapter/policy.go b/pkg/adapter/policy.go new file mode 100644 index 00000000..f3759c76 --- /dev/null +++ b/pkg/adapter/policy.go @@ -0,0 +1,57 @@ +// Copyright 2016 beego authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web" + beecontext "github.com/astaxie/beego/pkg/server/web/context" +) + +// PolicyFunc defines a policy function which is invoked before the controller handler is executed. +type PolicyFunc func(*context.Context) + +// FindPolicy Find Router info for URL +func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { + pf := (*web.ControllerRegister)(p).FindPolicy((*beecontext.Context)(cont)) + npf := newToOldPolicyFunc(pf) + return npf +} + +func newToOldPolicyFunc(pf []web.PolicyFunc) []PolicyFunc { + npf := make([]PolicyFunc, 0, len(pf)) + for _, f := range pf { + npf = append(npf, func(c *context.Context) { + f((*beecontext.Context)(c)) + }) + } + return npf +} + +func oldToNewPolicyFunc(pf []PolicyFunc) []web.PolicyFunc { + npf := make([]web.PolicyFunc, 0, len(pf)) + for _, f := range pf { + npf = append(npf, func(c *beecontext.Context) { + f((*context.Context)(c)) + }) + } + return npf +} + +// Policy Register new policy in beego +func Policy(pattern, method string, policy ...PolicyFunc) { + pf := oldToNewPolicyFunc(policy) + web.Policy(pattern, method, pf...) +} diff --git a/pkg/adapter/router.go b/pkg/adapter/router.go new file mode 100644 index 00000000..8e8d9fdb --- /dev/null +++ b/pkg/adapter/router.go @@ -0,0 +1,282 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + "time" + + beecontext "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +// default filter execution points +const ( + BeforeStatic = web.BeforeStatic + BeforeRouter = web.BeforeRouter + BeforeExec = web.BeforeExec + AfterExec = web.AfterExec + FinishRouter = web.FinishRouter +) + +var ( + // HTTPMETHOD list the supported http methods. + HTTPMETHOD = web.HTTPMETHOD + + // DefaultAccessLogFilter will skip the accesslog if return true + DefaultAccessLogFilter FilterHandler = &newToOldFtHdlAdapter{ + delegate: web.DefaultAccessLogFilter, + } +) + +// FilterHandler is an interface for +type FilterHandler interface { + Filter(*beecontext.Context) bool +} + +type newToOldFtHdlAdapter struct { + delegate web.FilterHandler +} + +func (n *newToOldFtHdlAdapter) Filter(ctx *beecontext.Context) bool { + return n.delegate.Filter((*context.Context)(ctx)) +} + +// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter +func ExceptMethodAppend(action string) { + web.ExceptMethodAppend(action) +} + +// ControllerInfo holds information about the controller. +type ControllerInfo web.ControllerInfo + +func (c *ControllerInfo) GetPattern() string { + return (*web.ControllerInfo)(c).GetPattern() +} + +// ControllerRegister containers registered router rules, controller handlers and filters. +type ControllerRegister web.ControllerRegister + +// NewControllerRegister returns a new ControllerRegister. +func NewControllerRegister() *ControllerRegister { + return (*ControllerRegister)(web.NewControllerRegister()) +} + +// Add controller handler and pattern rules to ControllerRegister. +// usage: +// default methods is the same name as method +// Add("/user",&UserController{}) +// Add("/api/list",&RestController{},"*:ListFood") +// Add("/api/create",&RestController{},"post:CreateFood") +// Add("/api/update",&RestController{},"put:UpdateFood") +// Add("/api/delete",&RestController{},"delete:DeleteFood") +// Add("/api",&RestController{},"get,post:ApiFunc" +// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { + (*web.ControllerRegister)(p).Add(pattern, c, mappingMethods...) +} + +// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller +// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +func (p *ControllerRegister) Include(cList ...ControllerInterface) { + nls := oldToNewCtrlIntfs(cList) + (*web.ControllerRegister)(p).Include(nls...) +} + +// GetContext returns a context from pool, so usually you should remember to call Reset function to clean the context +// And don't forget to give back context to pool +// example: +// ctx := p.GetContext() +// ctx.Reset(w, q) +// defer p.GiveBackContext(ctx) +func (p *ControllerRegister) GetContext() *beecontext.Context { + return (*beecontext.Context)((*web.ControllerRegister)(p).GetContext()) +} + +// GiveBackContext put the ctx into pool so that it could be reuse +func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { + (*web.ControllerRegister)(p).GiveBackContext((*context.Context)(ctx)) +} + +// Get add get method +// usage: +// Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Get(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Get(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Post add post method +// usage: +// Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Post(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Post(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Put add put method +// usage: +// Put("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Put(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Put(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Delete add delete method +// usage: +// Delete("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Delete(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Head add head method +// usage: +// Head("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Head(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Head(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Patch add patch method +// usage: +// Patch("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Patch(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Options add options method +// usage: +// Options("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Options(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Options(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Any add all method +// usage: +// Any("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Any(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Any(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// AddMethod add http method router +// usage: +// AddMethod("get","/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).AddMethod(method, pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Handler add user defined Handler +func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { + (*web.ControllerRegister)(p).Handler(pattern, h, options) +} + +// AddAuto router to ControllerRegister. +// example beego.AddAuto(&MainContorlller{}), +// MainController has method List and Page. +// visit the url /main/list to execute List function +// /main/page to execute Page function. +func (p *ControllerRegister) AddAuto(c ControllerInterface) { + (*web.ControllerRegister)(p).AddAuto(c) +} + +// AddAutoPrefix Add auto router to ControllerRegister with prefix. +// example beego.AddAutoPrefix("/admin",&MainContorlller{}), +// MainController has method List and Page. +// visit the url /admin/main/list to execute List function +// /admin/main/page to execute Page function. +func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) { + (*web.ControllerRegister)(p).AddAutoPrefix(prefix, c) +} + +// InsertFilter Add a FilterFunc with pattern rule and action constant. +// params is for: +// 1. setting the returnOnOutput value (false allows multiple filters to execute) +// 2. determining whether or not params need to be reset. +func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { + opts := oldToNewFilterOpts(params) + return (*web.ControllerRegister)(p).InsertFilter(pattern, pos, func(ctx *context.Context) { + filter((*beecontext.Context)(ctx)) + }, opts...) +} + +func oldToNewFilterOpts(params []bool) []web.FilterOpt { + opts := make([]web.FilterOpt, 0, 4) + if len(params) > 0 { + opts = append(opts, web.WithReturnOnOutput(params[0])) + } else { + // the default value should be true + opts = append(opts, web.WithReturnOnOutput(true)) + } + if len(params) > 1 { + opts = append(opts, web.WithResetParams(params[1])) + } + return opts +} + +// URLFor does another controller handler in this request function. +// it can access any controller method. +func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { + return (*web.ControllerRegister)(p).URLFor(endpoint, values...) +} + +// Implement http.Handler interface. +func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + (*web.ControllerRegister)(p).ServeHTTP(rw, r) +} + +// FindRouter Find Router info for URL +func (p *ControllerRegister) FindRouter(ctx *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { + r, ok := (*web.ControllerRegister)(p).FindRouter((*context.Context)(ctx)) + return (*ControllerInfo)(r), ok +} + +// LogAccess logging info HTTP Access +func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { + web.LogAccess((*context.Context)(ctx), startTime, statusCode) +} diff --git a/pkg/adapter/session/couchbase/sess_couchbase.go b/pkg/adapter/session/couchbase/sess_couchbase.go new file mode 100644 index 00000000..bce09641 --- /dev/null +++ b/pkg/adapter/session/couchbase/sess_couchbase.go @@ -0,0 +1,118 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package couchbase for session provider +// +// depend on github.com/couchbaselabs/go-couchbasee +// +// go install github.com/couchbaselabs/go-couchbase +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/couchbase" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("couchbase", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package couchbase + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + beecb "github.com/astaxie/beego/pkg/infrastructure/session/couchbase" +) + +// SessionStore store each session +type SessionStore beecb.SessionStore + +// Provider couchabse provided +type Provider beecb.Provider + +// Set value to couchabse session +func (cs *SessionStore) Set(key, value interface{}) error { + return (*beecb.SessionStore)(cs).Set(context.Background(), key, value) +} + +// Get value from couchabse session +func (cs *SessionStore) Get(key interface{}) interface{} { + return (*beecb.SessionStore)(cs).Get(context.Background(), key) +} + +// Delete value in couchbase session by given key +func (cs *SessionStore) Delete(key interface{}) error { + return (*beecb.SessionStore)(cs).Delete(context.Background(), key) +} + +// Flush Clean all values in couchbase session +func (cs *SessionStore) Flush() error { + return (*beecb.SessionStore)(cs).Flush(context.Background()) +} + +// SessionID Get couchbase session store id +func (cs *SessionStore) SessionID() string { + return (*beecb.SessionStore)(cs).SessionID(context.Background()) +} + +// SessionRelease Write couchbase session with Gob string +func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beecb.SessionStore)(cs).SessionRelease(context.Background(), w) +} + +// SessionInit init couchbase session +// savepath like couchbase server REST/JSON URL +// e.g. http://host:port/, Pool, Bucket +func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*beecb.Provider)(cp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read couchbase session by sid +func (cp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beecb.Provider)(cp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist Check couchbase session exist. +// it checkes sid exist or not. +func (cp *Provider) SessionExist(sid string) bool { + res, _ := (*beecb.Provider)(cp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate remove oldsid and use sid to generate new session +func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beecb.Provider)(cp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy Remove bucket in this couchbase +func (cp *Provider) SessionDestroy(sid string) error { + return (*beecb.Provider)(cp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Recycle +func (cp *Provider) SessionGC() { + (*beecb.Provider)(cp).SessionGC(context.Background()) +} + +// SessionAll return all active session +func (cp *Provider) SessionAll() int { + return (*beecb.Provider)(cp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/ledis/ledis_session.go b/pkg/adapter/session/ledis/ledis_session.go new file mode 100644 index 00000000..96198837 --- /dev/null +++ b/pkg/adapter/session/ledis/ledis_session.go @@ -0,0 +1,86 @@ +// Package ledis provide session Provider +package ledis + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + beeLedis "github.com/astaxie/beego/pkg/infrastructure/session/ledis" +) + +// SessionStore ledis session store +type SessionStore beeLedis.SessionStore + +// Set value in ledis session +func (ls *SessionStore) Set(key, value interface{}) error { + return (*beeLedis.SessionStore)(ls).Set(context.Background(), key, value) +} + +// Get value in ledis session +func (ls *SessionStore) Get(key interface{}) interface{} { + return (*beeLedis.SessionStore)(ls).Get(context.Background(), key) +} + +// Delete value in ledis session +func (ls *SessionStore) Delete(key interface{}) error { + return (*beeLedis.SessionStore)(ls).Delete(context.Background(), key) +} + +// Flush clear all values in ledis session +func (ls *SessionStore) Flush() error { + return (*beeLedis.SessionStore)(ls).Flush(context.Background()) +} + +// SessionID get ledis session id +func (ls *SessionStore) SessionID() string { + return (*beeLedis.SessionStore)(ls).SessionID(context.Background()) +} + +// SessionRelease save session values to ledis +func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beeLedis.SessionStore)(ls).SessionRelease(context.Background(), w) +} + +// Provider ledis session provider +type Provider beeLedis.Provider + +// SessionInit init ledis session +// savepath like ledis server saveDataPath,pool size +// e.g. 127.0.0.1:6379,100,astaxie +func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*beeLedis.Provider)(lp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read ledis session by sid +func (lp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beeLedis.Provider)(lp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check ledis session exist by sid +func (lp *Provider) SessionExist(sid string) bool { + res, _ := (*beeLedis.Provider)(lp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for ledis session +func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beeLedis.Provider)(lp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete ledis session by id +func (lp *Provider) SessionDestroy(sid string) error { + return (*beeLedis.Provider)(lp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (lp *Provider) SessionGC() { + (*beeLedis.Provider)(lp).SessionGC(context.Background()) +} + +// SessionAll return all active session +func (lp *Provider) SessionAll() int { + return (*beeLedis.Provider)(lp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/memcache/sess_memcache.go b/pkg/adapter/session/memcache/sess_memcache.go new file mode 100644 index 00000000..8afa79aa --- /dev/null +++ b/pkg/adapter/session/memcache/sess_memcache.go @@ -0,0 +1,118 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memcache for session provider +// +// depend on github.com/bradfitz/gomemcache/memcache +// +// go install github.com/bradfitz/gomemcache/memcache +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/memcache" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package memcache + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + beemem "github.com/astaxie/beego/pkg/infrastructure/session/memcache" +) + +// SessionStore memcache session store +type SessionStore beemem.SessionStore + +// Set value in memcache session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*beemem.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in memcache session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*beemem.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in memcache session +func (rs *SessionStore) Delete(key interface{}) error { + return (*beemem.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in memcache session +func (rs *SessionStore) Flush() error { + return (*beemem.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get memcache session id +func (rs *SessionStore) SessionID() string { + return (*beemem.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to memcache +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beemem.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// MemProvider memcache session provider +type MemProvider beemem.MemProvider + +// SessionInit init memcache session +// savepath like +// e.g. 127.0.0.1:9090 +func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + return (*beemem.MemProvider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read memcache session by sid +func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { + s, err := (*beemem.MemProvider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check memcache session exist by sid +func (rp *MemProvider) SessionExist(sid string) bool { + res, _ := (*beemem.MemProvider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for memcache session +func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beemem.MemProvider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete memcache session by id +func (rp *MemProvider) SessionDestroy(sid string) error { + return (*beemem.MemProvider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *MemProvider) SessionGC() { + (*beemem.MemProvider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *MemProvider) SessionAll() int { + return (*beemem.MemProvider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/mysql/sess_mysql.go b/pkg/adapter/session/mysql/sess_mysql.go new file mode 100644 index 00000000..1850a380 --- /dev/null +++ b/pkg/adapter/session/mysql/sess_mysql.go @@ -0,0 +1,135 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mysql for session provider +// +// depends on github.com/go-sql-driver/mysql: +// +// go install github.com/go-sql-driver/mysql +// +// mysql session support need create table as sql: +// CREATE TABLE `session` ( +// `session_key` char(64) NOT NULL, +// `session_data` blob, +// `session_expiry` int(11) unsigned NOT NULL, +// PRIMARY KEY (`session_key`) +// ) ENGINE=MyISAM DEFAULT CHARSET=utf8; +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/mysql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package mysql + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + "github.com/astaxie/beego/pkg/infrastructure/session/mysql" + + // import mysql driver + _ "github.com/go-sql-driver/mysql" +) + +var ( + // TableName store the session in MySQL + TableName = mysql.TableName + mysqlpder = &Provider{} +) + +// SessionStore mysql session store +type SessionStore mysql.SessionStore + +// Set value in mysql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + return (*mysql.SessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from mysql session +func (st *SessionStore) Get(key interface{}) interface{} { + return (*mysql.SessionStore)(st).Get(context.Background(), key) +} + +// Delete value in mysql session +func (st *SessionStore) Delete(key interface{}) error { + return (*mysql.SessionStore)(st).Delete(context.Background(), key) +} + +// Flush clear all values in mysql session +func (st *SessionStore) Flush() error { + return (*mysql.SessionStore)(st).Flush(context.Background()) +} + +// SessionID get session id of this mysql session store +func (st *SessionStore) SessionID() string { + return (*mysql.SessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease save mysql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + (*mysql.SessionStore)(st).SessionRelease(context.Background(), w) +} + +// Provider mysql session provider +type Provider mysql.Provider + +// SessionInit init mysql session. +// savepath is the connection string of mysql. +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*mysql.Provider)(mp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead get mysql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*mysql.Provider)(mp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check mysql session exist +func (mp *Provider) SessionExist(sid string) bool { + res, _ := (*mysql.Provider)(mp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for mysql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*mysql.Provider)(mp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete mysql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + return (*mysql.Provider)(mp).SessionDestroy(context.Background(), sid) +} + +// SessionGC delete expired values in mysql session +func (mp *Provider) SessionGC() { + (*mysql.Provider)(mp).SessionGC(context.Background()) +} + +// SessionAll count values in mysql session +func (mp *Provider) SessionAll() int { + return (*mysql.Provider)(mp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/postgres/sess_postgresql.go b/pkg/adapter/session/postgres/sess_postgresql.go new file mode 100644 index 00000000..de1adbc4 --- /dev/null +++ b/pkg/adapter/session/postgres/sess_postgresql.go @@ -0,0 +1,139 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres for session provider +// +// depends on github.com/lib/pq: +// +// go install github.com/lib/pq +// +// +// needs this table in your database: +// +// CREATE TABLE session ( +// session_key char(64) NOT NULL, +// session_data bytea, +// session_expiry timestamp NOT NULL, +// CONSTRAINT session_key PRIMARY KEY(session_key) +// ); +// +// will be activated with these settings in app.conf: +// +// SessionOn = true +// SessionProvider = postgresql +// SessionSavePath = "user=a password=b dbname=c sslmode=disable" +// SessionName = session +// +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/postgresql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package postgres + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + // import postgresql Driver + _ "github.com/lib/pq" + + "github.com/astaxie/beego/pkg/infrastructure/session/postgres" +) + +// SessionStore postgresql session store +type SessionStore postgres.SessionStore + +// Set value in postgresql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + return (*postgres.SessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from postgresql session +func (st *SessionStore) Get(key interface{}) interface{} { + return (*postgres.SessionStore)(st).Get(context.Background(), key) +} + +// Delete value in postgresql session +func (st *SessionStore) Delete(key interface{}) error { + return (*postgres.SessionStore)(st).Delete(context.Background(), key) +} + +// Flush clear all values in postgresql session +func (st *SessionStore) Flush() error { + return (*postgres.SessionStore)(st).Flush(context.Background()) +} + +// SessionID get session id of this postgresql session store +func (st *SessionStore) SessionID() string { + return (*postgres.SessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease save postgresql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + (*postgres.SessionStore)(st).SessionRelease(context.Background(), w) +} + +// Provider postgresql session provider +type Provider postgres.Provider + +// SessionInit init postgresql session. +// savepath is the connection string of postgresql. +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*postgres.Provider)(mp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead get postgresql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*postgres.Provider)(mp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check postgresql session exist +func (mp *Provider) SessionExist(sid string) bool { + res, _ := (*postgres.Provider)(mp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for postgresql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*postgres.Provider)(mp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete postgresql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + return (*postgres.Provider)(mp).SessionDestroy(context.Background(), sid) +} + +// SessionGC delete expired values in postgresql session +func (mp *Provider) SessionGC() { + (*postgres.Provider)(mp).SessionGC(context.Background()) +} + +// SessionAll count values in postgresql session +func (mp *Provider) SessionAll() int { + return (*postgres.Provider)(mp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/provider_adapter.go b/pkg/adapter/session/provider_adapter.go new file mode 100644 index 00000000..11177a4d --- /dev/null +++ b/pkg/adapter/session/provider_adapter.go @@ -0,0 +1,104 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +type oldToNewProviderAdapter struct { + delegate Provider +} + +func (o *oldToNewProviderAdapter) SessionInit(ctx context.Context, gclifetime int64, config string) error { + return o.delegate.SessionInit(gclifetime, config) +} + +func (o *oldToNewProviderAdapter) SessionRead(ctx context.Context, sid string) (session.Store, error) { + store, err := o.delegate.SessionRead(sid) + return &oldToNewStoreAdapter{ + delegate: store, + }, err +} + +func (o *oldToNewProviderAdapter) SessionExist(ctx context.Context, sid string) (bool, error) { + return o.delegate.SessionExist(sid), nil +} + +func (o *oldToNewProviderAdapter) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { + s, err := o.delegate.SessionRegenerate(oldsid, sid) + return &oldToNewStoreAdapter{ + delegate: s, + }, err +} + +func (o *oldToNewProviderAdapter) SessionDestroy(ctx context.Context, sid string) error { + return o.delegate.SessionDestroy(sid) +} + +func (o *oldToNewProviderAdapter) SessionAll(ctx context.Context) int { + return o.delegate.SessionAll() +} + +func (o *oldToNewProviderAdapter) SessionGC(ctx context.Context) { + o.delegate.SessionGC() +} + +type newToOldProviderAdapter struct { + delegate session.Provider +} + +func (n *newToOldProviderAdapter) SessionInit(gclifetime int64, config string) error { + return n.delegate.SessionInit(context.Background(), gclifetime, config) +} + +func (n *newToOldProviderAdapter) SessionRead(sid string) (Store, error) { + s, err := n.delegate.SessionRead(context.Background(), sid) + if adt, ok := s.(*oldToNewStoreAdapter); err == nil && ok { + return adt.delegate, err + } + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +func (n *newToOldProviderAdapter) SessionExist(sid string) bool { + res, _ := n.delegate.SessionExist(context.Background(), sid) + return res +} + +func (n *newToOldProviderAdapter) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := n.delegate.SessionRegenerate(context.Background(), oldsid, sid) + if adt, ok := s.(*oldToNewStoreAdapter); err == nil && ok { + return adt.delegate, err + } + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +func (n *newToOldProviderAdapter) SessionDestroy(sid string) error { + return n.delegate.SessionDestroy(context.Background(), sid) +} + +func (n *newToOldProviderAdapter) SessionAll() int { + return n.delegate.SessionAll(context.Background()) +} + +func (n *newToOldProviderAdapter) SessionGC() { + n.delegate.SessionGC(context.Background()) +} diff --git a/pkg/adapter/session/redis/sess_redis.go b/pkg/adapter/session/redis/sess_redis.go new file mode 100644 index 00000000..6c521e50 --- /dev/null +++ b/pkg/adapter/session/redis/sess_redis.go @@ -0,0 +1,121 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + beeRedis "github.com/astaxie/beego/pkg/infrastructure/session/redis" +) + +// MaxPoolSize redis max pool size +var MaxPoolSize = beeRedis.MaxPoolSize + +// SessionStore redis session store +type SessionStore beeRedis.SessionStore + +// Set value in redis session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*beeRedis.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in redis session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*beeRedis.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in redis session +func (rs *SessionStore) Delete(key interface{}) error { + return (*beeRedis.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in redis session +func (rs *SessionStore) Flush() error { + return (*beeRedis.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get redis session id +func (rs *SessionStore) SessionID() string { + return (*beeRedis.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to redis +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beeRedis.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// Provider redis session provider +type Provider beeRedis.Provider + +// SessionInit init redis session +// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second +// e.g. 127.0.0.1:6379,100,astaxie,0,30 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*beeRedis.Provider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read redis session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beeRedis.Provider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check redis session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + res, _ := (*beeRedis.Provider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for redis session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beeRedis.Provider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + return (*beeRedis.Provider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { + (*beeRedis.Provider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return (*beeRedis.Provider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/redis_cluster/redis_cluster.go b/pkg/adapter/session/redis_cluster/redis_cluster.go new file mode 100644 index 00000000..03a805e4 --- /dev/null +++ b/pkg/adapter/session/redis_cluster/redis_cluster.go @@ -0,0 +1,120 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_cluster" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis_cluster + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + cluster "github.com/astaxie/beego/pkg/infrastructure/session/redis_cluster" +) + +// MaxPoolSize redis_cluster max pool size +var MaxPoolSize = cluster.MaxPoolSize + +// SessionStore redis_cluster session store +type SessionStore cluster.SessionStore + +// Set value in redis_cluster session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*cluster.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in redis_cluster session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*cluster.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in redis_cluster session +func (rs *SessionStore) Delete(key interface{}) error { + return (*cluster.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in redis_cluster session +func (rs *SessionStore) Flush() error { + return (*cluster.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get redis_cluster session id +func (rs *SessionStore) SessionID() string { + return (*cluster.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to redis_cluster +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*cluster.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// Provider redis_cluster session provider +type Provider cluster.Provider + +// SessionInit init redis_cluster session +// savepath like redis server addr,pool size,password,dbnum +// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*cluster.Provider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read redis_cluster session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*cluster.Provider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check redis_cluster session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + res, _ := (*cluster.Provider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for redis_cluster session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*cluster.Provider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + return (*cluster.Provider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { + (*cluster.Provider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return (*cluster.Provider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go new file mode 100644 index 00000000..f5eb8a4f --- /dev/null +++ b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go @@ -0,0 +1,121 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_sentinel" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``) +// go globalSessions.GC() +// } +// +// more detail about params: please check the notes on the function SessionInit in this package +package redis_sentinel + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + sentinel "github.com/astaxie/beego/pkg/infrastructure/session/redis_sentinel" +) + +// DefaultPoolSize redis_sentinel default pool size +var DefaultPoolSize = sentinel.DefaultPoolSize + +// SessionStore redis_sentinel session store +type SessionStore sentinel.SessionStore + +// Set value in redis_sentinel session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*sentinel.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in redis_sentinel session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*sentinel.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in redis_sentinel session +func (rs *SessionStore) Delete(key interface{}) error { + return (*sentinel.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in redis_sentinel session +func (rs *SessionStore) Flush() error { + return (*sentinel.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get redis_sentinel session id +func (rs *SessionStore) SessionID() string { + return (*sentinel.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to redis_sentinel +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*sentinel.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// Provider redis_sentinel session provider +type Provider sentinel.Provider + +// SessionInit init redis_sentinel session +// savepath like redis sentinel addr,pool size,password,dbnum,masterName +// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*sentinel.Provider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read redis_sentinel session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*sentinel.Provider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check redis_sentinel session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + res, _ := (*sentinel.Provider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for redis_sentinel session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*sentinel.Provider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + return (*sentinel.Provider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { + (*sentinel.Provider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return (*sentinel.Provider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go new file mode 100644 index 00000000..7c33985f --- /dev/null +++ b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go @@ -0,0 +1,90 @@ +package redis_sentinel + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/astaxie/beego/pkg/adapter/session" +) + +func TestRedisSentinel(t *testing.T) { + sessionConfig := &session.ManagerConfig{ + CookieName: "gosessionid", + EnableSetCookie: true, + Gclifetime: 3600, + Maxlifetime: 3600, + Secure: false, + CookieLifeTime: 3600, + ProviderConfig: "127.0.0.1:6379,100,,0,master", + } + globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) + if e != nil { + t.Log(e) + return + } + // todo test if e==nil + go globalSessions.GC() + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start failed:", err) + } + defer sess.SessionRelease(w) + + // SET AND GET + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set username failed:", err) + } + username := sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + + // DELETE + err = sess.Delete("username") + if err != nil { + t.Fatal("delete username failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("delete username failed") + } + + // FLUSH + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set failed:", err) + } + err = sess.Set("password", "1qaz2wsx") + if err != nil { + t.Fatal("set failed:", err) + } + username = sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + password := sess.Get("password") + if password != "1qaz2wsx" { + t.Fatal("get password failed") + } + err = sess.Flush() + if err != nil { + t.Fatal("flush failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("flush failed") + } + password = sess.Get("password") + if password != nil { + t.Fatal("flush failed") + } + + sess.SessionRelease(w) + +} diff --git a/pkg/adapter/session/sess_cookie.go b/pkg/adapter/session/sess_cookie.go new file mode 100644 index 00000000..32216040 --- /dev/null +++ b/pkg/adapter/session/sess_cookie.go @@ -0,0 +1,114 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// CookieSessionStore Cookie SessionStore +type CookieSessionStore session.CookieSessionStore + +// Set value to cookie session. +// the value are encoded as gob with hash block string. +func (st *CookieSessionStore) Set(key, value interface{}) error { + return (*session.CookieSessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from cookie session +func (st *CookieSessionStore) Get(key interface{}) interface{} { + return (*session.CookieSessionStore)(st).Get(context.Background(), key) +} + +// Delete value in cookie session +func (st *CookieSessionStore) Delete(key interface{}) error { + return (*session.CookieSessionStore)(st).Delete(context.Background(), key) +} + +// Flush Clean all values in cookie session +func (st *CookieSessionStore) Flush() error { + return (*session.CookieSessionStore)(st).Flush(context.Background()) +} + +// SessionID Return id of this cookie session +func (st *CookieSessionStore) SessionID() string { + return (*session.CookieSessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease Write cookie session to http response cookie +func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { + (*session.CookieSessionStore)(st).SessionRelease(context.Background(), w) +} + +// CookieProvider Cookie session provider +type CookieProvider session.CookieProvider + +// SessionInit Init cookie session provider with max lifetime and config json. +// maxlifetime is ignored. +// json config: +// securityKey - hash string +// blockKey - gob encode hash string. it's saved as aes crypto. +// securityName - recognized name in encoded cookie string +// cookieName - cookie name +// maxage - cookie max life time. +func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { + return (*session.CookieProvider)(pder).SessionInit(context.Background(), maxlifetime, config) +} + +// SessionRead Get SessionStore in cooke. +// decode cooke string to map and put into SessionStore with sid. +func (pder *CookieProvider) SessionRead(sid string) (Store, error) { + s, err := (*session.CookieProvider)(pder).SessionRead(context.Background(), sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionExist Cookie session is always existed +func (pder *CookieProvider) SessionExist(sid string) bool { + res, _ := (*session.CookieProvider)(pder).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate Implement method, no used. +func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := (*session.CookieProvider)(pder).SessionRegenerate(context.Background(), oldsid, sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionDestroy Implement method, no used. +func (pder *CookieProvider) SessionDestroy(sid string) error { + return (*session.CookieProvider)(pder).SessionDestroy(context.Background(), sid) +} + +// SessionGC Implement method, no used. +func (pder *CookieProvider) SessionGC() { + (*session.CookieProvider)(pder).SessionGC(context.Background()) +} + +// SessionAll Implement method, return 0. +func (pder *CookieProvider) SessionAll() int { + return (*session.CookieProvider)(pder).SessionAll(context.Background()) +} + +// SessionUpdate Implement method, no used. +func (pder *CookieProvider) SessionUpdate(sid string) error { + return (*session.CookieProvider)(pder).SessionUpdate(context.Background(), sid) +} diff --git a/pkg/adapter/session/sess_cookie_test.go b/pkg/adapter/session/sess_cookie_test.go new file mode 100644 index 00000000..b6726005 --- /dev/null +++ b/pkg/adapter/session/sess_cookie_test.go @@ -0,0 +1,105 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + sess.SessionRelease(w) + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} + +func TestDestorySessionCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + session, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start err,", err) + } + + // request again ,will get same sesssion id . + r1, _ := http.NewRequest("GET", "/", nil) + r1.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + w = httptest.NewRecorder() + newSession, err := globalSessions.SessionStart(w, r1) + if err != nil { + t.Fatal("session start err,", err) + } + if newSession.SessionID() != session.SessionID() { + t.Fatal("get cookie session id is not the same again.") + } + + // After destroy session , will get a new session id . + globalSessions.SessionDestroy(w, r1) + r2, _ := http.NewRequest("GET", "/", nil) + r2.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + + w = httptest.NewRecorder() + newSession, err = globalSessions.SessionStart(w, r2) + if err != nil { + t.Fatal("session start error") + } + if newSession.SessionID() == session.SessionID() { + t.Fatal("after destroy session and reqeust again ,get cookie session id is same.") + } +} diff --git a/pkg/adapter/session/sess_file.go b/pkg/adapter/session/sess_file.go new file mode 100644 index 00000000..b9648998 --- /dev/null +++ b/pkg/adapter/session/sess_file.go @@ -0,0 +1,106 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// FileSessionStore File session store +type FileSessionStore session.FileSessionStore + +// Set value to file session +func (fs *FileSessionStore) Set(key, value interface{}) error { + return (*session.FileSessionStore)(fs).Set(context.Background(), key, value) +} + +// Get value from file session +func (fs *FileSessionStore) Get(key interface{}) interface{} { + return (*session.FileSessionStore)(fs).Get(context.Background(), key) +} + +// Delete value in file session by given key +func (fs *FileSessionStore) Delete(key interface{}) error { + return (*session.FileSessionStore)(fs).Delete(context.Background(), key) +} + +// Flush Clean all values in file session +func (fs *FileSessionStore) Flush() error { + return (*session.FileSessionStore)(fs).Flush(context.Background()) +} + +// SessionID Get file session store id +func (fs *FileSessionStore) SessionID() string { + return (*session.FileSessionStore)(fs).SessionID(context.Background()) +} + +// SessionRelease Write file session to local file with Gob string +func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { + (*session.FileSessionStore)(fs).SessionRelease(context.Background(), w) +} + +// FileProvider File session provider +type FileProvider session.FileProvider + +// SessionInit Init file session provider. +// savePath sets the session files path. +func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { + return (*session.FileProvider)(fp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead Read file session by sid. +// if file is not exist, create it. +// the file path is generated from sid string. +func (fp *FileProvider) SessionRead(sid string) (Store, error) { + s, err := (*session.FileProvider)(fp).SessionRead(context.Background(), sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionExist Check file session exist. +// it checks the file named from sid exist or not. +func (fp *FileProvider) SessionExist(sid string) bool { + res, _ := (*session.FileProvider)(fp).SessionExist(context.Background(), sid) + return res +} + +// SessionDestroy Remove all files in this save path +func (fp *FileProvider) SessionDestroy(sid string) error { + return (*session.FileProvider)(fp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Recycle files in save path +func (fp *FileProvider) SessionGC() { + (*session.FileProvider)(fp).SessionGC(context.Background()) +} + +// SessionAll Get active file session number. +// it walks save path to count files. +func (fp *FileProvider) SessionAll() int { + return (*session.FileProvider)(fp).SessionAll(context.Background()) +} + +// SessionRegenerate Generate new sid for file session. +// it delete old file and create new file named from new sid. +func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := (*session.FileProvider)(fp).SessionRegenerate(context.Background(), oldsid, sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} diff --git a/pkg/adapter/session/sess_file_test.go b/pkg/adapter/session/sess_file_test.go new file mode 100644 index 00000000..4c90a3ac --- /dev/null +++ b/pkg/adapter/session/sess_file_test.go @@ -0,0 +1,336 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "fmt" + "os" + "sync" + "testing" + "time" +) + +const sid = "Session_id" +const sidNew = "Session_id_new" +const sessionPath = "./_session_runtime" + +var ( + mutex sync.Mutex +) + +func TestFileProvider_SessionExist(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionExist2(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + if fp.SessionExist("") { + t.Error() + } + + if fp.SessionExist("1") { + t.Error() + } +} + +func TestFileProvider_SessionRead(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + _ = s.Set("sessionValue", 18975) + v := s.Get("sessionValue") + + if v.(int) != 18975 { + t.Error() + } +} + +func TestFileProvider_SessionRead1(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead("") + if err == nil { + t.Error(err) + } + + _, err = fp.SessionRead("1") + if err == nil { + t.Error(err) + } +} + +func TestFileProvider_SessionAll(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 546 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + if fp.SessionAll() != sessionCount { + t.Error() + } +} + +func TestFileProvider_SessionRegenerate(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + _, err = fp.SessionRegenerate(sid, sidNew) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } + + if !fp.SessionExist(sidNew) { + t.Error() + } +} + +func TestFileProvider_SessionDestroy(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + err = fp.SessionDestroy(sid) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionGC(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(1, sessionPath) + + sessionCount := 412 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + time.Sleep(2 * time.Second) + + fp.SessionGC() + if fp.SessionAll() != 0 { + t.Error() + } +} + +func TestFileSessionStore_Set(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + err := s.Set(i, i) + if err != nil { + t.Error(err) + } + } +} + +func TestFileSessionStore_Get(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + + v := s.Get(i) + if v.(int) != i { + t.Error() + } + } +} + +func TestFileSessionStore_Delete(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, _ := fp.SessionRead(sid) + s.Set("1", 1) + + if s.Get("1") == nil { + t.Error() + } + + s.Delete("1") + + if s.Get("1") != nil { + t.Error() + } +} + +func TestFileSessionStore_Flush(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + } + + _ = s.Flush() + + for i := 1; i <= sessionCount; i++ { + if s.Get(i) != nil { + t.Error() + } + } +} + +func TestFileSessionStore_SessionID(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) { + t.Error(err) + } + } +} diff --git a/pkg/adapter/session/sess_mem.go b/pkg/adapter/session/sess_mem.go new file mode 100644 index 00000000..818c8329 --- /dev/null +++ b/pkg/adapter/session/sess_mem.go @@ -0,0 +1,106 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// MemSessionStore memory session store. +// it saved sessions in a map in memory. +type MemSessionStore session.MemSessionStore + +// Set value to memory session +func (st *MemSessionStore) Set(key, value interface{}) error { + return (*session.MemSessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from memory session by key +func (st *MemSessionStore) Get(key interface{}) interface{} { + return (*session.MemSessionStore)(st).Get(context.Background(), key) +} + +// Delete in memory session by key +func (st *MemSessionStore) Delete(key interface{}) error { + return (*session.MemSessionStore)(st).Delete(context.Background(), key) +} + +// Flush clear all values in memory session +func (st *MemSessionStore) Flush() error { + return (*session.MemSessionStore)(st).Flush(context.Background()) +} + +// SessionID get this id of memory session store +func (st *MemSessionStore) SessionID() string { + return (*session.MemSessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease Implement method, no used. +func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { + (*session.MemSessionStore)(st).SessionRelease(context.Background(), w) +} + +// MemProvider Implement the provider interface +type MemProvider session.MemProvider + +// SessionInit init memory session +func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + return (*session.MemProvider)(pder).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead get memory session store by sid +func (pder *MemProvider) SessionRead(sid string) (Store, error) { + s, err := (*session.MemProvider)(pder).SessionRead(context.Background(), sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionExist check session store exist in memory session by sid +func (pder *MemProvider) SessionExist(sid string) bool { + res, _ := (*session.MemProvider)(pder).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for session store in memory session +func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := (*session.MemProvider)(pder).SessionRegenerate(context.Background(), oldsid, sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionDestroy delete session store in memory session by id +func (pder *MemProvider) SessionDestroy(sid string) error { + return (*session.MemProvider)(pder).SessionDestroy(context.Background(), sid) +} + +// SessionGC clean expired session stores in memory session +func (pder *MemProvider) SessionGC() { + (*session.MemProvider)(pder).SessionGC(context.Background()) +} + +// SessionAll get count number of memory session +func (pder *MemProvider) SessionAll() int { + return (*session.MemProvider)(pder).SessionAll(context.Background()) +} + +// SessionUpdate expand time of session store by id in memory session +func (pder *MemProvider) SessionUpdate(sid string) error { + return (*session.MemProvider)(pder).SessionUpdate(context.Background(), sid) +} diff --git a/pkg/adapter/session/sess_mem_test.go b/pkg/adapter/session/sess_mem_test.go new file mode 100644 index 00000000..2e8934b8 --- /dev/null +++ b/pkg/adapter/session/sess_mem_test.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestMem(t *testing.T) { + config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, _ := NewManager("memory", conf) + go globalSessions.GC() + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + defer sess.SessionRelease(w) + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} diff --git a/pkg/adapter/session/sess_test.go b/pkg/adapter/session/sess_test.go new file mode 100644 index 00000000..aba702ca --- /dev/null +++ b/pkg/adapter/session/sess_test.go @@ -0,0 +1,51 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "testing" +) + +func Test_gob(t *testing.T) { + a := make(map[interface{}]interface{}) + a["username"] = "astaxie" + a[12] = 234 + a["user"] = User{"asta", "xie"} + b, err := EncodeGob(a) + if err != nil { + t.Error(err) + } + c, err := DecodeGob(b) + if err != nil { + t.Error(err) + } + if len(c) == 0 { + t.Error("decodeGob empty") + } + if c["username"] != "astaxie" { + t.Error("decode string error") + } + if c[12] != 234 { + t.Error("decode int error") + } + if c["user"].(User).Username != "asta" { + t.Error("decode struct error") + } +} + +type User struct { + Username string + NickName string +} diff --git a/pkg/adapter/session/sess_utils.go b/pkg/adapter/session/sess_utils.go new file mode 100644 index 00000000..3d107198 --- /dev/null +++ b/pkg/adapter/session/sess_utils.go @@ -0,0 +1,29 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// EncodeGob encode the obj to gob +func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) { + return session.EncodeGob(obj) +} + +// DecodeGob decode data to map +func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) { + return session.DecodeGob(encoded) +} diff --git a/pkg/adapter/session/session.go b/pkg/adapter/session/session.go new file mode 100644 index 00000000..eea2f90e --- /dev/null +++ b/pkg/adapter/session/session.go @@ -0,0 +1,166 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package session provider +// +// Usage: +// import( +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package session + +import ( + "io" + "net/http" + "os" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// Store contains all data for one session process with specific id. +type Store interface { + Set(key, value interface{}) error // set session value + Get(key interface{}) interface{} // get session value + Delete(key interface{}) error // delete session value + SessionID() string // back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error // delete all data +} + +// Provider contains global session methods and saved SessionStores. +// it can operate a SessionStore by its id. +type Provider interface { + SessionInit(gclifetime int64, config string) error + SessionRead(sid string) (Store, error) + SessionExist(sid string) bool + SessionRegenerate(oldsid, sid string) (Store, error) + SessionDestroy(sid string) error + SessionAll() int // get all active session + SessionGC() +} + +// SLogger a helpful variable to log information about session +var SLogger = NewSessionLog(os.Stderr) + +// Register makes a session provide available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, provide Provider) { + session.Register(name, &oldToNewProviderAdapter{ + delegate: provide, + }) +} + +// GetProvider +func GetProvider(name string) (Provider, error) { + res, err := session.GetProvider(name) + if adt, ok := res.(*oldToNewProviderAdapter); err == nil && ok { + return adt.delegate, err + } + + return &newToOldProviderAdapter{ + delegate: res, + }, err +} + +// ManagerConfig define the session config +type ManagerConfig session.ManagerConfig + +// Manager contains Provider and its configuration. +type Manager session.Manager + +// NewManager Create new Manager with provider name and json config string. +// provider name: +// 1. cookie +// 2. file +// 3. memory +// 4. redis +// 5. mysql +// json config: +// 1. is https default false +// 2. hashfunc default sha1 +// 3. hashkey default beegosessionkey +// 4. maxage default is none +func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { + m, err := session.NewManager(provideName, (*session.ManagerConfig)(cf)) + return (*Manager)(m), err +} + +// GetProvider return current manager's provider +func (manager *Manager) GetProvider() Provider { + return &newToOldProviderAdapter{ + delegate: (*session.Manager)(manager).GetProvider(), + } +} + +// SessionStart generate or read the session id from http request. +// if session id exists, return SessionStore with this id. +func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (Store, error) { + s, err := (*session.Manager)(manager).SessionStart(w, r) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionDestroy Destroy session by its id in http request cookie. +func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { + (*session.Manager)(manager).SessionDestroy(w, r) +} + +// GetSessionStore Get SessionStore by its id. +func (manager *Manager) GetSessionStore(sid string) (Store, error) { + s, err := (*session.Manager)(manager).GetSessionStore(sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// GC Start session gc process. +// it can do gc in times after gc lifetime. +func (manager *Manager) GC() { + (*session.Manager)(manager).GC() +} + +// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. +func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) Store { + s := (*session.Manager)(manager).SessionRegenerateID(w, r) + return &NewToOldStoreAdapter{ + delegate: s, + } +} + +// GetActiveSession Get all active sessions count number. +func (manager *Manager) GetActiveSession() int { + return (*session.Manager)(manager).GetActiveSession() +} + +// SetSecure Set cookie with https. +func (manager *Manager) SetSecure(secure bool) { + (*session.Manager)(manager).SetSecure(secure) +} + +// Log implement the log.Logger +type Log session.Log + +// NewSessionLog set io.Writer to create a Logger for session. +func NewSessionLog(out io.Writer) *Log { + return (*Log)(session.NewSessionLog(out)) +} diff --git a/pkg/adapter/session/ssdb/sess_ssdb.go b/pkg/adapter/session/ssdb/sess_ssdb.go new file mode 100644 index 00000000..aee3a364 --- /dev/null +++ b/pkg/adapter/session/ssdb/sess_ssdb.go @@ -0,0 +1,84 @@ +package ssdb + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + beeSsdb "github.com/astaxie/beego/pkg/infrastructure/session/ssdb" +) + +// Provider holds ssdb client and configs +type Provider beeSsdb.Provider + +// SessionInit init the ssdb with the config +func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { + return (*beeSsdb.Provider)(p).SessionInit(context.Background(), maxLifetime, savePath) +} + +// SessionRead return a ssdb client session Store +func (p *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beeSsdb.Provider)(p).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist judged whether sid is exist in session +func (p *Provider) SessionExist(sid string) bool { + res, _ := (*beeSsdb.Provider)(p).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate regenerate session with new sid and delete oldsid +func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beeSsdb.Provider)(p).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy destroy the sid +func (p *Provider) SessionDestroy(sid string) error { + return (*beeSsdb.Provider)(p).SessionDestroy(context.Background(), sid) +} + +// SessionGC not implemented +func (p *Provider) SessionGC() { + (*beeSsdb.Provider)(p).SessionGC(context.Background()) +} + +// SessionAll not implemented +func (p *Provider) SessionAll() int { + return (*beeSsdb.Provider)(p).SessionAll(context.Background()) +} + +// SessionStore holds the session information which stored in ssdb +type SessionStore beeSsdb.SessionStore + +// Set the key and value +func (s *SessionStore) Set(key, value interface{}) error { + return (*beeSsdb.SessionStore)(s).Set(context.Background(), key, value) +} + +// Get return the value by the key +func (s *SessionStore) Get(key interface{}) interface{} { + return (*beeSsdb.SessionStore)(s).Get(context.Background(), key) +} + +// Delete the key in session store +func (s *SessionStore) Delete(key interface{}) error { + return (*beeSsdb.SessionStore)(s).Delete(context.Background(), key) +} + +// Flush delete all keys and values +func (s *SessionStore) Flush() error { + return (*beeSsdb.SessionStore)(s).Flush(context.Background()) +} + +// SessionID return the sessionID +func (s *SessionStore) SessionID() string { + return (*beeSsdb.SessionStore)(s).SessionID(context.Background()) +} + +// SessionRelease Store the keyvalues into ssdb +func (s *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beeSsdb.SessionStore)(s).SessionRelease(context.Background(), w) +} diff --git a/pkg/adapter/session/store_adapter.go b/pkg/adapter/session/store_adapter.go new file mode 100644 index 00000000..c1a03c38 --- /dev/null +++ b/pkg/adapter/session/store_adapter.go @@ -0,0 +1,84 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +type NewToOldStoreAdapter struct { + delegate session.Store +} + +func CreateNewToOldStoreAdapter(s session.Store) Store { + return &NewToOldStoreAdapter{ + delegate: s, + } +} + +func (n *NewToOldStoreAdapter) Set(key, value interface{}) error { + return n.delegate.Set(context.Background(), key, value) +} + +func (n *NewToOldStoreAdapter) Get(key interface{}) interface{} { + return n.delegate.Get(context.Background(), key) +} + +func (n *NewToOldStoreAdapter) Delete(key interface{}) error { + return n.delegate.Delete(context.Background(), key) +} + +func (n *NewToOldStoreAdapter) SessionID() string { + return n.delegate.SessionID(context.Background()) +} + +func (n *NewToOldStoreAdapter) SessionRelease(w http.ResponseWriter) { + n.delegate.SessionRelease(context.Background(), w) +} + +func (n *NewToOldStoreAdapter) Flush() error { + return n.delegate.Flush(context.Background()) +} + +type oldToNewStoreAdapter struct { + delegate Store +} + +func (o *oldToNewStoreAdapter) Set(ctx context.Context, key, value interface{}) error { + return o.delegate.Set(key, value) +} + +func (o *oldToNewStoreAdapter) Get(ctx context.Context, key interface{}) interface{} { + return o.delegate.Get(key) +} + +func (o *oldToNewStoreAdapter) Delete(ctx context.Context, key interface{}) error { + return o.delegate.Delete(key) +} + +func (o *oldToNewStoreAdapter) SessionID(ctx context.Context) string { + return o.delegate.SessionID() +} + +func (o *oldToNewStoreAdapter) SessionRelease(ctx context.Context, w http.ResponseWriter) { + o.delegate.SessionRelease(w) +} + +func (o *oldToNewStoreAdapter) Flush(ctx context.Context) error { + return o.delegate.Flush() +} diff --git a/pkg/adapter/swagger/swagger.go b/pkg/adapter/swagger/swagger.go new file mode 100644 index 00000000..214959d9 --- /dev/null +++ b/pkg/adapter/swagger/swagger.go @@ -0,0 +1,68 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Swagger™ is a project used to describe and document RESTful APIs. +// +// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools. +// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software. + +// Package swagger struct definition +package swagger + +import ( + "github.com/astaxie/beego/pkg/server/web/swagger" +) + +// Swagger list the resource +type Swagger swagger.Swagger + +// Information Provides metadata about the API. The metadata can be used by the clients if needed. +type Information swagger.Information + +// Contact information for the exposed API. +type Contact swagger.Contact + +// License information for the exposed API. +type License swagger.License + +// Item Describes the operations available on a single path. +type Item swagger.Item + +// Operation Describes a single API operation on a path. +type Operation swagger.Operation + +// Parameter Describes a single operation parameter. +type Parameter swagger.Parameter + +// ParameterItems A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body". +// http://swagger.io/specification/#itemsObject +type ParameterItems swagger.ParameterItems + +// Schema Object allows the definition of input and output data types. +type Schema swagger.Schema + +// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification +type Propertie swagger.Propertie + +// Response as they are returned from executing this operation. +type Response swagger.Response + +// Security Allows the definition of a security scheme that can be used by the operations +type Security swagger.Security + +// Tag Allows adding meta data to a single tag that is used by the Operation Object +type Tag swagger.Tag + +// ExternalDocs include Additional external documentation +type ExternalDocs swagger.ExternalDocs diff --git a/pkg/adapter/template.go b/pkg/adapter/template.go new file mode 100644 index 00000000..1f943caf --- /dev/null +++ b/pkg/adapter/template.go @@ -0,0 +1,108 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "html/template" + "io" + "net/http" + + "github.com/astaxie/beego/pkg/server/web" +) + +// ExecuteTemplate applies the template with name to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { + return web.ExecuteTemplate(wr, name, data) +} + +// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error { + return web.ExecuteViewPathTemplate(wr, name, viewPath, data) +} + +// AddFuncMap let user to register a func in the template. +func AddFuncMap(key string, fn interface{}) error { + return web.AddFuncMap(key, fn) +} + +type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error) + +type templateFile struct { + root string + files map[string][]string +} + +// HasTemplateExt return this path contains supported template extension of beego or not. +func HasTemplateExt(paths string) bool { + return web.HasTemplateExt(paths) +} + +// AddTemplateExt add new extension for template. +func AddTemplateExt(ext string) { + web.AddTemplateExt(ext) +} + +// AddViewPath adds a new path to the supported view paths. +// Can later be used by setting a controller ViewPath to this folder +// will panic if called after beego.Run() +func AddViewPath(viewPath string) error { + return web.AddViewPath(viewPath) +} + +// BuildTemplate will build all template files in a directory. +// it makes beego can render any template file in view directory. +func BuildTemplate(dir string, files ...string) error { + return web.BuildTemplate(dir, files...) +} + +type templateFSFunc func() http.FileSystem + +func defaultFSFunc() http.FileSystem { + return FileSystem{} +} + +// SetTemplateFSFunc set default filesystem function +func SetTemplateFSFunc(fnt templateFSFunc) { + web.SetTemplateFSFunc(func() http.FileSystem { + return fnt() + }) +} + +// SetViewsPath sets view directory path in beego application. +func SetViewsPath(path string) *App { + return (*App)(web.SetViewsPath(path)) +} + +// SetStaticPath sets static directory path and proper url pattern in beego application. +// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". +func SetStaticPath(url string, path string) *App { + return (*App)(web.SetStaticPath(url, path)) +} + +// DelStaticPath removes the static folder setting in this url pattern in beego application. +func DelStaticPath(url string) *App { + return (*App)(web.DelStaticPath(url)) +} + +// AddTemplateEngine add a new templatePreProcessor which support extension +func AddTemplateEngine(extension string, fn templatePreProcessor) *App { + return (*App)(web.AddTemplateEngine(extension, func(root, path string, funcs template.FuncMap) (*template.Template, error) { + return fn(root, path, funcs) + })) +} diff --git a/pkg/adapter/templatefunc.go b/pkg/adapter/templatefunc.go new file mode 100644 index 00000000..5130d590 --- /dev/null +++ b/pkg/adapter/templatefunc.go @@ -0,0 +1,151 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "html/template" + "net/url" + "time" + + "github.com/astaxie/beego/pkg/server/web" +) + +const ( + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" + formatDateTimeT = "2006-01-02T15:04:05" +) + +// Substr returns the substr from start to length. +func Substr(s string, start, length int) string { + return web.Substr(s, start, length) +} + +// HTML2str returns escaping text convert from html. +func HTML2str(html string) string { + return web.HTML2str(html) +} + +// DateFormat takes a time and a layout string and returns a string with the formatted date. Used by the template parser as "dateformat" +func DateFormat(t time.Time, layout string) (datestring string) { + return web.DateFormat(t, layout) +} + +// DateParse Parse Date use PHP time format. +func DateParse(dateString, format string) (time.Time, error) { + return web.DateParse(dateString, format) +} + +// Date takes a PHP like date func to Go's time format. +func Date(t time.Time, format string) string { + return web.Date(t, format) +} + +// Compare is a quick and dirty comparison function. It will convert whatever you give it to strings and see if the two values are equal. +// Whitespace is trimmed. Used by the template parser as "eq". +func Compare(a, b interface{}) (equal bool) { + return web.Compare(a, b) +} + +// CompareNot !Compare +func CompareNot(a, b interface{}) (equal bool) { + return web.CompareNot(a, b) +} + +// NotNil the same as CompareNot +func NotNil(a interface{}) (isNil bool) { + return web.NotNil(a) +} + +// GetConfig get the Appconfig +func GetConfig(returnType, key string, defaultVal interface{}) (interface{}, error) { + return web.GetConfig(returnType, key, defaultVal) +} + +// Str2html Convert string to template.HTML type. +func Str2html(raw string) template.HTML { + return web.Str2html(raw) +} + +// Htmlquote returns quoted html string. +func Htmlquote(text string) string { + return web.Htmlquote(text) +} + +// Htmlunquote returns unquoted html string. +func Htmlunquote(text string) string { + return web.Htmlunquote(text) +} + +// URLFor returns url string with another registered controller handler with params. +// usage: +// +// URLFor(".index") +// print URLFor("index") +// router /login +// print URLFor("login") +// print URLFor("login", "next","/"") +// router /profile/:username +// print UrlFor("profile", ":username","John Doe") +// result: +// / +// /login +// /login?next=/ +// /user/John%20Doe +// +// more detail http://beego.me/docs/mvc/controller/urlbuilding.md +func URLFor(endpoint string, values ...interface{}) string { + return web.URLFor(endpoint, values...) +} + +// AssetsJs returns script tag with src string. +func AssetsJs(text string) template.HTML { + return web.AssetsJs(text) +} + +// AssetsCSS returns stylesheet link tag with src string. +func AssetsCSS(text string) template.HTML { + + text = "" + + return template.HTML(text) +} + +// ParseForm will parse form values to struct via tag. +func ParseForm(form url.Values, obj interface{}) error { + return web.ParseForm(form, obj) +} + +// RenderForm will render object to form html. +// obj must be a struct pointer. +func RenderForm(obj interface{}) template.HTML { + return web.RenderForm(obj) +} + +// MapGet getting value from map by keys +// usage: +// Data["m"] = M{ +// "a": 1, +// "1": map[string]float64{ +// "c": 4, +// }, +// } +// +// {{ map_get m "a" }} // return 1 +// {{ map_get m 1 "c" }} // return 4 +func MapGet(arg1 interface{}, arg2 ...interface{}) (interface{}, error) { + return web.MapGet(arg1, arg2...) +} diff --git a/pkg/adapter/templatefunc_test.go b/pkg/adapter/templatefunc_test.go new file mode 100644 index 00000000..f5113606 --- /dev/null +++ b/pkg/adapter/templatefunc_test.go @@ -0,0 +1,304 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "html/template" + "net/url" + "testing" + "time" +) + +func TestSubstr(t *testing.T) { + s := `012345` + if Substr(s, 0, 2) != "01" { + t.Error("should be equal") + } + if Substr(s, 0, 100) != "012345" { + t.Error("should be equal") + } + if Substr(s, 12, 100) != "012345" { + t.Error("should be equal") + } +} + +func TestHtml2str(t *testing.T) { + h := `<123> 123\n + + + \n` + if HTML2str(h) != "123\\n\n\\n" { + t.Error("should be equal") + } +} + +func TestDateFormat(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := DateFormat(tt, "2006-01-02 15:04:05"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } +} + +func TestDate(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := Date(tt, "Y-m-d H:i:s"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } + if ss := Date(tt, "y-n-j h:i:s A"); ss != "13-7-1 01:27:42 PM" { + t.Errorf("13-7-1 01:27:42 PM does not equal %v", ss) + } + if ss := Date(tt, "D, d M Y g:i:s a"); ss != "Mon, 01 Jul 2013 1:27:42 pm" { + t.Errorf("Mon, 01 Jul 2013 1:27:42 pm does not equal %v", ss) + } + if ss := Date(tt, "l, d F Y G:i:s"); ss != "Monday, 01 July 2013 13:27:42" { + t.Errorf("Monday, 01 July 2013 13:27:42 does not equal %v", ss) + } +} + +func TestCompareRelated(t *testing.T) { + if !Compare("abc", "abc") { + t.Error("should be equal") + } + if Compare("abc", "aBc") { + t.Error("should be not equal") + } + if !Compare("1", 1) { + t.Error("should be equal") + } + if CompareNot("abc", "abc") { + t.Error("should be equal") + } + if !CompareNot("abc", "aBc") { + t.Error("should be not equal") + } + if !NotNil("a string") { + t.Error("should not be nil") + } +} + +func TestHtmlquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlquote(s) != h { + t.Error("should be equal") + } +} + +func TestHtmlunquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlunquote(h) != s { + t.Error("should be equal") + } +} + +func TestParseForm(t *testing.T) { + type ExtendInfo struct { + Hobby []string `form:"hobby"` + Memo string + } + + type OtherInfo struct { + Organization string `form:"organization"` + Title string `form:"title"` + ExtendInfo + } + + type user struct { + ID int `form:"-"` + tag string `form:"tag"` + Name interface{} `form:"username"` + Age int `form:"age,text"` + Email string + Intro string `form:",textarea"` + StrBool bool `form:"strbool"` + Date time.Time `form:"date,2006-01-02"` + OtherInfo + } + + u := user{} + form := url.Values{ + "ID": []string{"1"}, + "-": []string{"1"}, + "tag": []string{"no"}, + "username": []string{"test"}, + "age": []string{"40"}, + "Email": []string{"test@gmail.com"}, + "Intro": []string{"I am an engineer!"}, + "strbool": []string{"yes"}, + "date": []string{"2014-11-12"}, + "organization": []string{"beego"}, + "title": []string{"CXO"}, + "hobby": []string{"", "Basketball", "Football"}, + "memo": []string{"nothing"}, + } + if err := ParseForm(form, u); err == nil { + t.Fatal("nothing will be changed") + } + if err := ParseForm(form, &u); err != nil { + t.Fatal(err) + } + if u.ID != 0 { + t.Errorf("ID should equal 0 but got %v", u.ID) + } + if len(u.tag) != 0 { + t.Errorf("tag's length should equal 0 but got %v", len(u.tag)) + } + if u.Name.(string) != "test" { + t.Errorf("Name should equal `test` but got `%v`", u.Name.(string)) + } + if u.Age != 40 { + t.Errorf("Age should equal 40 but got %v", u.Age) + } + if u.Email != "test@gmail.com" { + t.Errorf("Email should equal `test@gmail.com` but got `%v`", u.Email) + } + if u.Intro != "I am an engineer!" { + t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) + } + if !u.StrBool { + t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) + } + y, m, d := u.Date.Date() + if y != 2014 || m.String() != "November" || d != 12 { + t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String()) + } + if u.Organization != "beego" { + t.Errorf("Organization should equal `beego`, but got `%v`", u.Organization) + } + if u.Title != "CXO" { + t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) + } + if u.Hobby[0] != "" { + t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0]) + } + if u.Hobby[1] != "Basketball" { + t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1]) + } + if u.Hobby[2] != "Football" { + t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2]) + } + if len(u.Memo) != 0 { + t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) + } +} + +func TestRenderForm(t *testing.T) { + type user struct { + ID int `form:"-"` + Name interface{} `form:"username"` + Age int `form:"age,text,年龄:"` + Sex string + Email []string + Intro string `form:",textarea"` + Ignored string `form:"-"` + } + + u := user{Name: "test", Intro: "Some Text"} + output := RenderForm(u) + if output != template.HTML("") { + t.Errorf("output should be empty but got %v", output) + } + output = RenderForm(&u) + result := template.HTML( + `Name:
` + + `年龄:
` + + `Sex:
` + + `Intro: `) + if output != result { + t.Errorf("output should equal `%v` but got `%v`", result, output) + } +} + +func TestMapGet(t *testing.T) { + // test one level map + m1 := map[string]int64{ + "a": 1, + "1": 2, + } + + if res, err := MapGet(m1, "a"); err == nil { + if res.(int64) != 1 { + t.Errorf("Should return 1, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, "1"); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, 1); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 2 level map + m2 := M{ + "1": map[string]float64{ + "2": 3.5, + }, + } + + if res, err := MapGet(m2, 1, 2); err == nil { + if res.(float64) != 3.5 { + t.Errorf("Should return 3.5, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 5 level map + m5 := M{ + "1": M{ + "2": M{ + "3": M{ + "4": M{ + "5": 1.2, + }, + }, + }, + }, + } + + if res, err := MapGet(m5, 1, 2, 3, 4, 5); err == nil { + if res.(float64) != 1.2 { + t.Errorf("Should return 1.2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // check whether element not exists in map + if res, err := MapGet(m5, 5, 4, 3, 2, 1); err == nil { + if res != nil { + t.Errorf("Should return nil, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } +} diff --git a/pkg/adapter/testing/client.go b/pkg/adapter/testing/client.go new file mode 100644 index 00000000..688aa6f3 --- /dev/null +++ b/pkg/adapter/testing/client.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing + +import ( + "github.com/astaxie/beego/pkg/client/httplib/testing" +) + +var port = "" +var baseURL = "http://localhost:" + +// TestHTTPRequest beego test request client +type TestHTTPRequest testing.TestHTTPRequest + +// Get returns test client in GET method +func Get(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Get(path)) +} + +// Post returns test client in POST method +func Post(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Post(path)) +} + +// Put returns test client in PUT method +func Put(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Put(path)) +} + +// Delete returns test client in DELETE method +func Delete(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Delete(path)) +} + +// Head returns test client in HEAD method +func Head(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Head(path)) +} diff --git a/pkg/adapter/toolbox/healthcheck.go b/pkg/adapter/toolbox/healthcheck.go new file mode 100644 index 00000000..56be8089 --- /dev/null +++ b/pkg/adapter/toolbox/healthcheck.go @@ -0,0 +1,52 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package toolbox healthcheck +// +// type DatabaseCheck struct { +// } +// +// func (dc *DatabaseCheck) Check() error { +// if dc.isConnected() { +// return nil +// } else { +// return errors.New("can't connect database") +// } +// } +// +// AddHealthCheck("database",&DatabaseCheck{}) +// +// more docs: http://beego.me/docs/module/toolbox.md +package toolbox + +import ( + "github.com/astaxie/beego/pkg/infrastructure/governor" +) + +// AdminCheckList holds health checker map +// Deprecated using governor.AdminCheckList +var AdminCheckList map[string]HealthChecker + +// HealthChecker health checker interface +type HealthChecker governor.HealthChecker + +// AddHealthCheck add health checker with name string +func AddHealthCheck(name string, hc HealthChecker) { + governor.AddHealthCheck(name, hc) + AdminCheckList[name] = hc +} + +func init() { + AdminCheckList = make(map[string]HealthChecker) +} diff --git a/pkg/adapter/toolbox/profile.go b/pkg/adapter/toolbox/profile.go new file mode 100644 index 00000000..16cf80b1 --- /dev/null +++ b/pkg/adapter/toolbox/profile.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "io" + "os" + "time" + + "github.com/astaxie/beego/pkg/infrastructure/governor" +) + +var startTime = time.Now() +var pid int + +func init() { + pid = os.Getpid() +} + +// ProcessInput parse input command string +func ProcessInput(input string, w io.Writer) { + governor.ProcessInput(input, w) +} + +// MemProf record memory profile in pprof +func MemProf(w io.Writer) { + governor.MemProf(w) +} + +// GetCPUProfile start cpu profile monitor +func GetCPUProfile(w io.Writer) { + governor.GetCPUProfile(w) +} + +// PrintGCSummary print gc information to io.Writer +func PrintGCSummary(w io.Writer) { + governor.PrintGCSummary(w) +} diff --git a/pkg/adapter/toolbox/profile_test.go b/pkg/adapter/toolbox/profile_test.go new file mode 100644 index 00000000..07a20c4e --- /dev/null +++ b/pkg/adapter/toolbox/profile_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "os" + "testing" +) + +func TestProcessInput(t *testing.T) { + ProcessInput("lookup goroutine", os.Stdout) + ProcessInput("lookup heap", os.Stdout) + ProcessInput("lookup threadcreate", os.Stdout) + ProcessInput("lookup block", os.Stdout) + ProcessInput("gc summary", os.Stdout) +} diff --git a/pkg/adapter/toolbox/statistics.go b/pkg/adapter/toolbox/statistics.go new file mode 100644 index 00000000..b7d3bda9 --- /dev/null +++ b/pkg/adapter/toolbox/statistics.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "time" + + "github.com/astaxie/beego/pkg/server/web" +) + +// Statistics struct +type Statistics web.Statistics + +// URLMap contains several statistics struct to log different data +type URLMap web.URLMap + +// AddStatistics add statistics task. +// it needs request method, request url, request controller and statistics time duration +func (m *URLMap) AddStatistics(requestMethod, requestURL, requestController string, requesttime time.Duration) { + (*web.URLMap)(m).AddStatistics(requestMethod, requestURL, requestController, requesttime) +} + +// GetMap put url statistics result in io.Writer +func (m *URLMap) GetMap() map[string]interface{} { + return (*web.URLMap)(m).GetMap() +} + +// GetMapData return all mapdata +func (m *URLMap) GetMapData() []map[string]interface{} { + return (*web.URLMap)(m).GetMapData() +} + +// StatisticsMap hosld global statistics data map +var StatisticsMap *URLMap + +func init() { + StatisticsMap = (*URLMap)(web.StatisticsMap) +} diff --git a/pkg/adapter/toolbox/statistics_test.go b/pkg/adapter/toolbox/statistics_test.go new file mode 100644 index 00000000..ac29476c --- /dev/null +++ b/pkg/adapter/toolbox/statistics_test.go @@ -0,0 +1,40 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "encoding/json" + "testing" + "time" +) + +func TestStatics(t *testing.T) { + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(2000)) + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(120000)) + StatisticsMap.AddStatistics("GET", "/api/user", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("POST", "/api/admin", "&admin.user", time.Duration(14000)) + StatisticsMap.AddStatistics("POST", "/api/user/astaxie", "&admin.user", time.Duration(12000)) + StatisticsMap.AddStatistics("POST", "/api/user/xiemengjun", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("DELETE", "/api/user", "&admin.user", time.Duration(1400)) + t.Log(StatisticsMap.GetMap()) + + data := StatisticsMap.GetMapData() + b, err := json.Marshal(data) + if err != nil { + t.Errorf(err.Error()) + } + + t.Log(string(b)) +} diff --git a/pkg/adapter/toolbox/task.go b/pkg/adapter/toolbox/task.go new file mode 100644 index 00000000..2a6d9aa6 --- /dev/null +++ b/pkg/adapter/toolbox/task.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "context" + "sort" + "time" + + "github.com/astaxie/beego/pkg/task" +) + +// The bounds for each field. +var ( + AdminTaskList map[string]Tasker +) + +const ( + // Set the top bit if a star was included in the expression. + starBit = 1 << 63 +) + +// Schedule time taks schedule +type Schedule task.Schedule + +// TaskFunc task func type +type TaskFunc func() error + +// Tasker task interface +type Tasker interface { + GetSpec() string + GetStatus() string + Run() error + SetNext(time.Time) + GetNext() time.Time + SetPrev(time.Time) + GetPrev() time.Time +} + +// task error +type taskerr struct { + t time.Time + errinfo string +} + +// Task task struct +// Deprecated +type Task struct { + // Deprecated + Taskname string + // Deprecated + Spec *Schedule + // Deprecated + SpecStr string + // Deprecated + DoFunc TaskFunc + // Deprecated + Prev time.Time + // Deprecated + Next time.Time + // Deprecated + Errlist []*taskerr // like errtime:errinfo + // Deprecated + ErrLimit int // max length for the errlist, 0 stand for no limit + + delegate *task.Task +} + +// NewTask add new task with name, time and func +func NewTask(tname string, spec string, f TaskFunc) *Task { + + task := task.NewTask(tname, spec, func(ctx context.Context) error { + return f() + }) + return &Task{ + delegate: task, + } +} + +// GetSpec get spec string +func (t *Task) GetSpec() string { + t.initDelegate() + + return t.delegate.GetSpec(context.Background()) +} + +// GetStatus get current task status +func (t *Task) GetStatus() string { + + t.initDelegate() + + return t.delegate.GetStatus(context.Background()) +} + +// Run run all tasks +func (t *Task) Run() error { + t.initDelegate() + return t.delegate.Run(context.Background()) +} + +// SetNext set next time for this task +func (t *Task) SetNext(now time.Time) { + t.initDelegate() + t.delegate.SetNext(context.Background(), now) +} + +// GetNext get the next call time of this task +func (t *Task) GetNext() time.Time { + t.initDelegate() + return t.delegate.GetNext(context.Background()) +} + +// SetPrev set prev time of this task +func (t *Task) SetPrev(now time.Time) { + t.initDelegate() + t.delegate.SetPrev(context.Background(), now) +} + +// GetPrev get prev time of this task +func (t *Task) GetPrev() time.Time { + t.initDelegate() + return t.delegate.GetPrev(context.Background()) +} + +// six columns mean: +// second:0-59 +// minute:0-59 +// hour:1-23 +// day:1-31 +// month:1-12 +// week:0-6(0 means Sunday) + +// SetCron some signals: +// *: any time +// ,:  separate signal +//    -:duration +// /n : do as n times of time duration +// /////////////////////////////////////////////////////// +// 0/30 * * * * * every 30s +// 0 43 21 * * * 21:43 +// 0 15 05 * * *    05:15 +// 0 0 17 * * * 17:00 +// 0 0 17 * * 1 17:00 in every Monday +// 0 0,10 17 * * 0,2,3 17:00 and 17:10 in every Sunday, Tuesday and Wednesday +// 0 0-10 17 1 * * 17:00 to 17:10 in 1 min duration each time on the first day of month +// 0 0 0 1,15 * 1 0:00 on the 1st day and 15th day of month +// 0 42 4 1 * *     4:42 on the 1st day of month +// 0 0 21 * * 1-6   21:00 from Monday to Saturday +// 0 0,10,20,30,40,50 * * * *  every 10 min duration +// 0 */10 * * * *        every 10 min duration +// 0 * 1 * * *         1:00 to 1:59 in 1 min duration each time +// 0 0 1 * * *         1:00 +// 0 0 */1 * * *        0 min of hour in 1 hour duration +// 0 0 * * * *         0 min of hour in 1 hour duration +// 0 2 8-20/3 * * *       8:02, 11:02, 14:02, 17:02, 20:02 +// 0 30 5 1,15 * *       5:30 on the 1st day and 15th day of month +func (t *Task) SetCron(spec string) { + t.initDelegate() + t.delegate.SetCron(spec) +} + +func (t *Task) initDelegate() { + if t.delegate == nil { + t.delegate = &task.Task{ + Taskname: t.Taskname, + Spec: (*task.Schedule)(t.Spec), + SpecStr: t.SpecStr, + DoFunc: func(ctx context.Context) error { + return t.DoFunc() + }, + Prev: t.Prev, + Next: t.Next, + ErrLimit: t.ErrLimit, + } + } +} + +// Next set schedule to next time +func (s *Schedule) Next(t time.Time) time.Time { + return (*task.Schedule)(s).Next(t) +} + +// StartTask start all tasks +func StartTask() { + task.StartTask() +} + +// StopTask stop all tasks +func StopTask() { + task.StopTask() +} + +// AddTask add task with name +func AddTask(taskname string, t Tasker) { + task.AddTask(taskname, &oldToNewAdapter{delegate: t}) +} + +// DeleteTask delete task with name +func DeleteTask(taskname string) { + task.DeleteTask(taskname) +} + +// MapSorter sort map for tasker +type MapSorter task.MapSorter + +// NewMapSorter create new tasker map +func NewMapSorter(m map[string]Tasker) *MapSorter { + + newTaskerMap := make(map[string]task.Tasker, len(m)) + + for key, value := range m { + newTaskerMap[key] = &oldToNewAdapter{ + delegate: value, + } + } + + return (*MapSorter)(task.NewMapSorter(newTaskerMap)) +} + +// Sort sort tasker map +func (ms *MapSorter) Sort() { + sort.Sort(ms) +} + +func (ms *MapSorter) Len() int { return len(ms.Keys) } +func (ms *MapSorter) Less(i, j int) bool { + if ms.Vals[i].GetNext(context.Background()).IsZero() { + return false + } + if ms.Vals[j].GetNext(context.Background()).IsZero() { + return true + } + return ms.Vals[i].GetNext(context.Background()).Before(ms.Vals[j].GetNext(context.Background())) +} +func (ms *MapSorter) Swap(i, j int) { + ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] + ms.Keys[i], ms.Keys[j] = ms.Keys[j], ms.Keys[i] +} + +func init() { + AdminTaskList = make(map[string]Tasker) +} + +type oldToNewAdapter struct { + delegate Tasker +} + +func (o *oldToNewAdapter) GetSpec(ctx context.Context) string { + return o.delegate.GetSpec() +} + +func (o *oldToNewAdapter) GetStatus(ctx context.Context) string { + return o.delegate.GetStatus() +} + +func (o *oldToNewAdapter) Run(ctx context.Context) error { + return o.delegate.Run() +} + +func (o *oldToNewAdapter) SetNext(ctx context.Context, t time.Time) { + o.delegate.SetNext(t) +} + +func (o *oldToNewAdapter) GetNext(ctx context.Context) time.Time { + return o.delegate.GetNext() +} + +func (o *oldToNewAdapter) SetPrev(ctx context.Context, t time.Time) { + o.delegate.SetPrev(t) +} + +func (o *oldToNewAdapter) GetPrev(ctx context.Context) time.Time { + return o.delegate.GetPrev() +} diff --git a/pkg/adapter/toolbox/task_test.go b/pkg/adapter/toolbox/task_test.go new file mode 100644 index 00000000..596bc9c5 --- /dev/null +++ b/pkg/adapter/toolbox/task_test.go @@ -0,0 +1,63 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "fmt" + "sync" + "testing" + "time" +) + +func TestParse(t *testing.T) { + tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) + err := tk.Run() + if err != nil { + t.Fatal(err) + } + AddTask("taska", tk) + StartTask() + time.Sleep(6 * time.Second) + StopTask() +} + +func TestSpec(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) + tk2 := NewTask("tk2", "0,10,20 * * * * *", func() error { fmt.Println("tk2"); wg.Done(); return nil }) + tk3 := NewTask("tk3", "0 10 * * * *", func() error { fmt.Println("tk3"); wg.Done(); return nil }) + + AddTask("tk1", tk1) + AddTask("tk2", tk2) + AddTask("tk3", tk3) + StartTask() + defer StopTask() + + select { + case <-time.After(200 * time.Second): + t.FailNow() + case <-wait(wg): + } +} + +func wait(wg *sync.WaitGroup) chan bool { + ch := make(chan bool) + go func() { + wg.Wait() + ch <- true + }() + return ch +} diff --git a/pkg/adapter/tree.go b/pkg/adapter/tree.go new file mode 100644 index 00000000..2e3cd0d0 --- /dev/null +++ b/pkg/adapter/tree.go @@ -0,0 +1,49 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +// Tree has three elements: FixRouter/wildcard/leaves +// fixRouter stores Fixed Router +// wildcard stores params +// leaves store the endpoint information +type Tree web.Tree + +// NewTree return a new Tree +func NewTree() *Tree { + return (*Tree)(web.NewTree()) +} + +// AddTree will add tree to the exist Tree +// prefix should has no params +func (t *Tree) AddTree(prefix string, tree *Tree) { + (*web.Tree)(t).AddTree(prefix, (*web.Tree)(tree)) +} + +// AddRouter call addseg function +func (t *Tree) AddRouter(pattern string, runObject interface{}) { + (*web.Tree)(t).AddRouter(pattern, runObject) +} + +// Match router to runObject & params +func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) { + return (*web.Tree)(t).Match(pattern, (*beecontext.Context)(ctx)) +} diff --git a/pkg/adapter/tree_test.go b/pkg/adapter/tree_test.go new file mode 100644 index 00000000..309ed072 --- /dev/null +++ b/pkg/adapter/tree_test.go @@ -0,0 +1,249 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "testing" + + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" +) + +type testinfo struct { + url string + requesturl string + params map[string]string +} + +var routers []testinfo + +func init() { + routers = make([]testinfo, 0) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic", nil}) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"}}) + routers = append(routers, testinfo{"/:id", "/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/hello/?:id", "/hello", map[string]string{":id": ""}}) + routers = append(routers, testinfo{"/", "/", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}}) + routers = append(routers, testinfo{"/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}}) + routers = append(routers, testinfo{"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}}) + routers = append(routers, testinfo{"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}}) + routers = append(routers, testinfo{"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}}) + routers = append(routers, testinfo{"/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"}}) + routers = append(routers, testinfo{"/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}}) + routers = append(routers, testinfo{"/thumbnail/:size/uploads/*", + "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", + map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}}) + routers = append(routers, testinfo{"/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/dl/:width:int/:height:int/*.*", + "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", + map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}}) + routers = append(routers, testinfo{"/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}}) + routers = append(routers, testinfo{"/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}}) +} + +func TestTreeRouters(t *testing.T) { + for _, r := range routers { + tr := NewTree() + tr.AddRouter(r.url, "astaxie") + ctx := context.NewContext() + obj := tr.Match(r.requesturl, ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal(r.url+" can't get obj, Expect ", r.requesturl) + } + if r.params != nil { + for k, v := range r.params { + if vv := ctx.Input.Param(k); vv != v { + t.Fatal("The Rule: " + r.url + "\nThe RequestURL:" + r.requesturl + "\nThe Key is " + k + ", The Value should be: " + v + ", but get: " + vv) + } else if vv == "" && v != "" { + t.Fatal(r.url + " " + r.requesturl + " get param empty:" + k) + } + } + } + } +} + +func TestStaticPath(t *testing.T) { + tr := NewTree() + tr.AddRouter("/topic/:id", "wildcard") + tr.AddRouter("/topic", "static") + ctx := context.NewContext() + obj := tr.Match("/topic", ctx) + if obj == nil || obj.(string) != "static" { + t.Fatal("/topic is a static route") + } + obj = tr.Match("/topic/1", ctx) + if obj == nil || obj.(string) != "wildcard" { + t.Fatal("/topic/1 is a wildcard route") + } +} + +func TestAddTree(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t1 := NewTree() + t1.AddTree("/v1/zl", tr) + ctx := context.NewContext() + obj := t1.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" { + t.Fatal("get :id param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t1.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl//shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" { + t.Fatal("get :sd :id :page param error") + } + + t2 := NewTree() + t2.AddTree("/v1/:shopid", tr) + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t2.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :id :shopid param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t2.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get :shopid param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :sd :id :page :shopid param error") + } +} + +func TestAddTree2(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t3 := NewTree() + t3.AddTree("/:version(v1|v2)/:prefix", tr) + ctx := context.NewContext() + obj := t3.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:version(v1|v2)/:prefix/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":prefix") != "zl" || ctx.Input.Param(":version") != "v1" { + t.Fatal("get :id :prefix :version param error") + } +} + +func TestAddTree3(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/account", "astaxie") + t3 := NewTree() + t3.AddTree("/table/:num", tr) + ctx := context.NewContext() + obj := t3.Match("/table/123/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/shop/:sd/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":num") != "123" || ctx.Input.Param(":sd") != "123" { + t.Fatal("get :num :sd param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t3.Match("/table/123/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/create can't get obj ") + } +} + +func TestAddTree4(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/:account", "astaxie") + t4 := NewTree() + t4.AddTree("/:info:int/:num/:id", tr) + ctx := context.NewContext() + obj := t4.Match("/12/123/456/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/shop/:sd/:account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":info") != "12" || ctx.Input.Param(":num") != "123" || + ctx.Input.Param(":id") != "456" || ctx.Input.Param(":sd") != "123" || + ctx.Input.Param(":account") != "account" { + t.Fatal("get :info :num :id :sd :account param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t4.Match("/12/123/456/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/create can't get obj ") + } +} + +// Test for issue #1595 +func TestAddTree5(t *testing.T) { + tr := NewTree() + tr.AddRouter("/v1/shop/:id", "shopdetail") + tr.AddRouter("/v1/shop/", "shophome") + ctx := context.NewContext() + obj := tr.Match("/v1/shop/", ctx) + if obj == nil || obj.(string) != "shophome" { + t.Fatal("url /v1/shop/ need match router /v1/shop/ ") + } +} diff --git a/pkg/adapter/utils/caller.go b/pkg/adapter/utils/caller.go new file mode 100644 index 00000000..d4fcc456 --- /dev/null +++ b/pkg/adapter/utils/caller.go @@ -0,0 +1,24 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// GetFuncName get function name +func GetFuncName(i interface{}) string { + return utils.GetFuncName(i) +} diff --git a/pkg/adapter/utils/caller_test.go b/pkg/adapter/utils/caller_test.go new file mode 100644 index 00000000..0675f0aa --- /dev/null +++ b/pkg/adapter/utils/caller_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "strings" + "testing" +) + +func TestGetFuncName(t *testing.T) { + name := GetFuncName(TestGetFuncName) + t.Log(name) + if !strings.HasSuffix(name, ".TestGetFuncName") { + t.Error("get func name error") + } +} diff --git a/pkg/adapter/utils/captcha/LICENSE b/pkg/adapter/utils/captcha/LICENSE new file mode 100644 index 00000000..0ad73ae0 --- /dev/null +++ b/pkg/adapter/utils/captcha/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2011-2014 Dmitry Chestnykh + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/pkg/adapter/utils/captcha/README.md b/pkg/adapter/utils/captcha/README.md new file mode 100644 index 00000000..dbc2026b --- /dev/null +++ b/pkg/adapter/utils/captcha/README.md @@ -0,0 +1,45 @@ +# Captcha + +an example for use captcha + +``` +package controllers + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/utils/captcha" +) + +var cpt *captcha.Captcha + +func init() { + // use beego cache system store the captcha data + store := cache.NewMemoryCache() + cpt = captcha.NewWithFilter("/captcha/", store) +} + +type MainController struct { + beego.Controller +} + +func (this *MainController) Get() { + this.TplName = "index.tpl" +} + +func (this *MainController) Post() { + this.TplName = "index.tpl" + + this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +} +``` + +template usage + +``` +{{.Success}} +
+ {{create_captcha}} + +
+``` diff --git a/pkg/adapter/utils/captcha/captcha.go b/pkg/adapter/utils/captcha/captcha.go new file mode 100644 index 00000000..faadc8bf --- /dev/null +++ b/pkg/adapter/utils/captcha/captcha.go @@ -0,0 +1,124 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package captcha implements generation and verification of image CAPTCHAs. +// an example for use captcha +// +// ``` +// package controllers +// +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/cache" +// "github.com/astaxie/beego/utils/captcha" +// ) +// +// var cpt *captcha.Captcha +// +// func init() { +// // use beego cache system store the captcha data +// store := cache.NewMemoryCache() +// cpt = captcha.NewWithFilter("/captcha/", store) +// } +// +// type MainController struct { +// beego.Controller +// } +// +// func (this *MainController) Get() { +// this.TplName = "index.tpl" +// } +// +// func (this *MainController) Post() { +// this.TplName = "index.tpl" +// +// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +// } +// ``` +// +// template usage +// +// ``` +// {{.Success}} +//
+// {{create_captcha}} +// +//
+// ``` +package captcha + +import ( + "html/template" + "net/http" + "time" + + "github.com/astaxie/beego/pkg/server/web/captcha" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/adapter/cache" + "github.com/astaxie/beego/pkg/adapter/context" +) + +var ( + defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +) + +const ( + // default captcha attributes + challengeNums = 6 + expiration = 600 * time.Second + fieldIDName = "captcha_id" + fieldCaptchaName = "captcha" + cachePrefix = "captcha_" + defaultURLPrefix = "/captcha/" +) + +// Captcha struct +type Captcha captcha.Captcha + +// Handler beego filter handler for serve captcha image +func (c *Captcha) Handler(ctx *context.Context) { + (*captcha.Captcha)(c).Handler((*beecontext.Context)(ctx)) +} + +// CreateCaptchaHTML template func for output html +func (c *Captcha) CreateCaptchaHTML() template.HTML { + return (*captcha.Captcha)(c).CreateCaptchaHTML() +} + +// CreateCaptcha create a new captcha id +func (c *Captcha) CreateCaptcha() (string, error) { + return (*captcha.Captcha)(c).CreateCaptcha() +} + +// VerifyReq verify from a request +func (c *Captcha) VerifyReq(req *http.Request) bool { + return (*captcha.Captcha)(c).VerifyReq(req) +} + +// Verify direct verify id and challenge string +func (c *Captcha) Verify(id string, challenge string) (success bool) { + return (*captcha.Captcha)(c).Verify(id, challenge) +} + +// NewCaptcha create a new captcha.Captcha +func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { + return (*Captcha)(captcha.NewCaptcha(urlPrefix, store)) +} + +// NewWithFilter create a new captcha.Captcha and auto AddFilter for serve captacha image +// and add a template func for output html +func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { + return (*Captcha)(captcha.NewWithFilter(urlPrefix, store)) +} diff --git a/pkg/adapter/utils/captcha/image.go b/pkg/adapter/utils/captcha/image.go new file mode 100644 index 00000000..9979db84 --- /dev/null +++ b/pkg/adapter/utils/captcha/image.go @@ -0,0 +1,35 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "io" + + "github.com/astaxie/beego/pkg/server/web/captcha" +) + +// Image struct +type Image captcha.Image + +// NewImage returns a new captcha image of the given width and height with the +// given digits, where each digit must be in range 0-9. +func NewImage(digits []byte, width, height int) *Image { + return (*Image)(captcha.NewImage(digits, width, height)) +} + +// WriteTo writes captcha image in PNG format into the given writer. +func (m *Image) WriteTo(w io.Writer) (int64, error) { + return (*captcha.Image)(m).WriteTo(w) +} diff --git a/pkg/adapter/utils/captcha/image_test.go b/pkg/adapter/utils/captcha/image_test.go new file mode 100644 index 00000000..bce2134a --- /dev/null +++ b/pkg/adapter/utils/captcha/image_test.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "testing" + + "github.com/astaxie/beego/pkg/adapter/utils" +) + +const ( + // Standard width and height of a captcha image. + stdWidth = 240 + stdHeight = 80 +) + +type byteCounter struct { + n int64 +} + +func (bc *byteCounter) Write(b []byte) (int, error) { + bc.n += int64(len(b)) + return len(b), nil +} + +func BenchmarkNewImage(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + for i := 0; i < b.N; i++ { + NewImage(d, stdWidth, stdHeight) + } +} + +func BenchmarkImageWriteTo(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + counter := &byteCounter{} + for i := 0; i < b.N; i++ { + img := NewImage(d, stdWidth, stdHeight) + img.WriteTo(counter) + b.SetBytes(counter.n) + counter.n = 0 + } +} diff --git a/pkg/adapter/utils/debug.go b/pkg/adapter/utils/debug.go new file mode 100644 index 00000000..d39f3d3e --- /dev/null +++ b/pkg/adapter/utils/debug.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// Display print the data in console +func Display(data ...interface{}) { + utils.Display(data...) +} + +// GetDisplayString return data print string +func GetDisplayString(data ...interface{}) string { + return utils.GetDisplayString(data...) +} + +// Stack get stack bytes +func Stack(skip int, indent string) []byte { + return utils.Stack(skip, indent) +} diff --git a/pkg/adapter/utils/debug_test.go b/pkg/adapter/utils/debug_test.go new file mode 100644 index 00000000..efb8924e --- /dev/null +++ b/pkg/adapter/utils/debug_test.go @@ -0,0 +1,46 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +type mytype struct { + next *mytype + prev *mytype +} + +func TestPrint(t *testing.T) { + Display("v1", 1, "v2", 2, "v3", 3) +} + +func TestPrintPoint(t *testing.T) { + var v1 = new(mytype) + var v2 = new(mytype) + + v1.prev = nil + v1.next = v2 + + v2.prev = v1 + v2.next = nil + + Display("v1", v1, "v2", v2) +} + +func TestPrintString(t *testing.T) { + str := GetDisplayString("v1", 1, "v2", 2) + println(str) +} diff --git a/pkg/adapter/utils/file.go b/pkg/adapter/utils/file.go new file mode 100644 index 00000000..8979389e --- /dev/null +++ b/pkg/adapter/utils/file.go @@ -0,0 +1,47 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// SelfPath gets compiled executable file absolute path +func SelfPath() string { + return utils.SelfPath() +} + +// SelfDir gets compiled executable file directory +func SelfDir() string { + return utils.SelfDir() +} + +// FileExists reports whether the named file or directory exists. +func FileExists(name string) bool { + return utils.FileExists(name) +} + +// SearchFile Search a file in paths. +// this is often used in search config file in /etc ~/ +func SearchFile(filename string, paths ...string) (fullpath string, err error) { + return utils.SearchFile(filename, paths...) +} + +// GrepFile like command grep -E +// for example: GrepFile(`^hello`, "hello.txt") +// \n is striped while read +func GrepFile(patten string, filename string) (lines []string, err error) { + return utils.GrepFile(patten, filename) +} diff --git a/pkg/adapter/utils/mail.go b/pkg/adapter/utils/mail.go new file mode 100644 index 00000000..35a58756 --- /dev/null +++ b/pkg/adapter/utils/mail.go @@ -0,0 +1,63 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "io" + + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// Email is the type used for email messages +type Email utils.Email + +// Attachment is a struct representing an email attachment. +// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question +type Attachment utils.Attachment + +// NewEMail create new Email struct with config json. +// config json is followed from Email struct fields. +func NewEMail(config string) *Email { + return (*Email)(utils.NewEMail(config)) +} + +// Bytes Make all send information to byte +func (e *Email) Bytes() ([]byte, error) { + return (*utils.Email)(e).Bytes() +} + +// AttachFile Add attach file to the send mail +func (e *Email) AttachFile(args ...string) (*Attachment, error) { + a, err := (*utils.Email)(e).AttachFile(args...) + if err != nil { + return nil, err + } + return (*Attachment)(a), err +} + +// Attach is used to attach content from an io.Reader to the email. +// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type. +func (e *Email) Attach(r io.Reader, filename string, args ...string) (*Attachment, error) { + a, err := (*utils.Email)(e).Attach(r, filename, args...) + if err != nil { + return nil, err + } + return (*Attachment)(a), err +} + +// Send will send out the mail +func (e *Email) Send() error { + return (*utils.Email)(e).Send() +} diff --git a/pkg/adapter/utils/mail_test.go b/pkg/adapter/utils/mail_test.go new file mode 100644 index 00000000..c38356a2 --- /dev/null +++ b/pkg/adapter/utils/mail_test.go @@ -0,0 +1,41 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestMail(t *testing.T) { + config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}` + mail := NewEMail(config) + if mail.Username != "astaxie@gmail.com" { + t.Fatal("email parse get username error") + } + if mail.Password != "astaxie" { + t.Fatal("email parse get password error") + } + if mail.Host != "smtp.gmail.com" { + t.Fatal("email parse get host error") + } + if mail.Port != 587 { + t.Fatal("email parse get port error") + } + mail.To = []string{"xiemengjun@gmail.com"} + mail.From = "astaxie@gmail.com" + mail.Subject = "hi, just from beego!" + mail.Text = "Text Body is, of course, supported!" + mail.HTML = "

Fancy Html is supported, too!

" + mail.AttachFile("/Users/astaxie/github/beego/beego.go") + mail.Send() +} diff --git a/pkg/adapter/utils/pagination/controller.go b/pkg/adapter/utils/pagination/controller.go new file mode 100644 index 00000000..a908d8b0 --- /dev/null +++ b/pkg/adapter/utils/pagination/controller.go @@ -0,0 +1,26 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/pagination" +) + +// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). +func SetPaginator(ctx *context.Context, per int, nums int64) (paginator *Paginator) { + return (*Paginator)(pagination.SetPaginator((*beecontext.Context)(ctx), per, nums)) +} diff --git a/pkg/adapter/utils/pagination/doc.go b/pkg/adapter/utils/pagination/doc.go new file mode 100644 index 00000000..9abc6d78 --- /dev/null +++ b/pkg/adapter/utils/pagination/doc.go @@ -0,0 +1,58 @@ +/* +Package pagination provides utilities to setup a paginator within the +context of a http request. + +Usage + +In your beego.Controller: + + package controllers + + import "github.com/astaxie/beego/utils/pagination" + + type PostsController struct { + beego.Controller + } + + func (this *PostsController) ListAllPosts() { + // sets this.Data["paginator"] with the current offset (from the url query param) + postsPerPage := 20 + paginator := pagination.SetPaginator(this.Ctx, postsPerPage, CountPosts()) + + // fetch the next 20 posts + this.Data["posts"] = ListPostsByOffsetAndLimit(paginator.Offset(), postsPerPage) + } + + +In your view templates: + + {{if .paginator.HasPages}} + + {{end}} + +See also + +http://beego.me/docs/mvc/view/page.md + +*/ +package pagination diff --git a/pkg/adapter/utils/pagination/paginator.go b/pkg/adapter/utils/pagination/paginator.go new file mode 100644 index 00000000..4bd4a1b0 --- /dev/null +++ b/pkg/adapter/utils/pagination/paginator.go @@ -0,0 +1,112 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/utils/pagination" +) + +// Paginator within the state of a http request. +type Paginator pagination.Paginator + +// PageNums Returns the total number of pages. +func (p *Paginator) PageNums() int { + return (*pagination.Paginator)(p).PageNums() +} + +// Nums Returns the total number of items (e.g. from doing SQL count). +func (p *Paginator) Nums() int64 { + return (*pagination.Paginator)(p).Nums() +} + +// SetNums Sets the total number of items. +func (p *Paginator) SetNums(nums interface{}) { + (*pagination.Paginator)(p).SetNums(nums) +} + +// Page Returns the current page. +func (p *Paginator) Page() int { + return (*pagination.Paginator)(p).Page() +} + +// Pages Returns a list of all pages. +// +// Usage (in a view template): +// +// {{range $index, $page := .paginator.Pages}} +// +// {{$page}} +// +// {{end}} +func (p *Paginator) Pages() []int { + return (*pagination.Paginator)(p).Pages() +} + +// PageLink Returns URL for a given page index. +func (p *Paginator) PageLink(page int) string { + return (*pagination.Paginator)(p).PageLink(page) +} + +// PageLinkPrev Returns URL to the previous page. +func (p *Paginator) PageLinkPrev() (link string) { + return (*pagination.Paginator)(p).PageLinkPrev() +} + +// PageLinkNext Returns URL to the next page. +func (p *Paginator) PageLinkNext() (link string) { + return (*pagination.Paginator)(p).PageLinkNext() +} + +// PageLinkFirst Returns URL to the first page. +func (p *Paginator) PageLinkFirst() (link string) { + return (*pagination.Paginator)(p).PageLinkFirst() +} + +// PageLinkLast Returns URL to the last page. +func (p *Paginator) PageLinkLast() (link string) { + return (*pagination.Paginator)(p).PageLinkLast() +} + +// HasPrev Returns true if the current page has a predecessor. +func (p *Paginator) HasPrev() bool { + return (*pagination.Paginator)(p).HasPrev() +} + +// HasNext Returns true if the current page has a successor. +func (p *Paginator) HasNext() bool { + return (*pagination.Paginator)(p).HasNext() +} + +// IsActive Returns true if the given page index points to the current page. +func (p *Paginator) IsActive(page int) bool { + return (*pagination.Paginator)(p).IsActive(page) +} + +// Offset Returns the current offset. +func (p *Paginator) Offset() int { + return (*pagination.Paginator)(p).Offset() +} + +// HasPages Returns true if there is more than one page. +func (p *Paginator) HasPages() bool { + return (*pagination.Paginator)(p).HasPages() +} + +// NewPaginator Instantiates a paginator struct for the current http request. +func NewPaginator(req *http.Request, per int, nums interface{}) *Paginator { + return (*Paginator)(pagination.NewPaginator(req, per, nums)) +} diff --git a/pkg/adapter/utils/rand.go b/pkg/adapter/utils/rand.go new file mode 100644 index 00000000..ae415cf3 --- /dev/null +++ b/pkg/adapter/utils/rand.go @@ -0,0 +1,24 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// RandomCreateBytes generate random []byte by specify chars. +func RandomCreateBytes(n int, alphabets ...byte) []byte { + return utils.RandomCreateBytes(n, alphabets...) +} diff --git a/pkg/adapter/utils/rand_test.go b/pkg/adapter/utils/rand_test.go new file mode 100644 index 00000000..6c238b5e --- /dev/null +++ b/pkg/adapter/utils/rand_test.go @@ -0,0 +1,33 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestRand_01(t *testing.T) { + bs0 := RandomCreateBytes(16) + bs1 := RandomCreateBytes(16) + + t.Log(string(bs0), string(bs1)) + if string(bs0) == string(bs1) { + t.FailNow() + } + + bs0 = RandomCreateBytes(4, []byte(`a`)...) + + if string(bs0) != "aaaa" { + t.FailNow() + } +} diff --git a/pkg/adapter/utils/safemap.go b/pkg/adapter/utils/safemap.go new file mode 100644 index 00000000..13e7bb46 --- /dev/null +++ b/pkg/adapter/utils/safemap.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// BeeMap is a map with lock +type BeeMap utils.BeeMap + +// NewBeeMap return new safemap +func NewBeeMap() *BeeMap { + return (*BeeMap)(utils.NewBeeMap()) +} + +// Get from maps return the k's value +func (m *BeeMap) Get(k interface{}) interface{} { + return (*utils.BeeMap)(m).Get(k) +} + +// Set Maps the given key and value. Returns false +// if the key is already in the map and changes nothing. +func (m *BeeMap) Set(k interface{}, v interface{}) bool { + return (*utils.BeeMap)(m).Set(k, v) +} + +// Check Returns true if k is exist in the map. +func (m *BeeMap) Check(k interface{}) bool { + return (*utils.BeeMap)(m).Check(k) +} + +// Delete the given key and value. +func (m *BeeMap) Delete(k interface{}) { + (*utils.BeeMap)(m).Delete(k) +} + +// Items returns all items in safemap. +func (m *BeeMap) Items() map[interface{}]interface{} { + return (*utils.BeeMap)(m).Items() +} + +// Count returns the number of items within the map. +func (m *BeeMap) Count() int { + return (*utils.BeeMap)(m).Count() +} diff --git a/pkg/adapter/utils/safemap_test.go b/pkg/adapter/utils/safemap_test.go new file mode 100644 index 00000000..65085195 --- /dev/null +++ b/pkg/adapter/utils/safemap_test.go @@ -0,0 +1,89 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +var safeMap *BeeMap + +func TestNewBeeMap(t *testing.T) { + safeMap = NewBeeMap() + if safeMap == nil { + t.Fatal("expected to return non-nil BeeMap", "got", safeMap) + } +} + +func TestSet(t *testing.T) { + safeMap = NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } +} + +func TestReSet(t *testing.T) { + safeMap := NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } + // set diff value + if ok := safeMap.Set("astaxie", -1); !ok { + t.Error("expected", true, "got", false) + } + + // set same value + if ok := safeMap.Set("astaxie", -1); ok { + t.Error("expected", false, "got", true) + } +} + +func TestCheck(t *testing.T) { + if exists := safeMap.Check("astaxie"); !exists { + t.Error("expected", true, "got", false) + } +} + +func TestGet(t *testing.T) { + if val := safeMap.Get("astaxie"); val.(int) != 1 { + t.Error("expected value", 1, "got", val) + } +} + +func TestDelete(t *testing.T) { + safeMap.Delete("astaxie") + if exists := safeMap.Check("astaxie"); exists { + t.Error("expected element to be deleted") + } +} + +func TestItems(t *testing.T) { + safeMap := NewBeeMap() + safeMap.Set("astaxie", "hello") + for k, v := range safeMap.Items() { + key := k.(string) + value := v.(string) + if key != "astaxie" { + t.Error("expected the key should be astaxie") + } + if value != "hello" { + t.Error("expected the value should be hello") + } + } +} + +func TestCount(t *testing.T) { + if count := safeMap.Count(); count != 0 { + t.Error("expected count to be", 0, "got", count) + } +} diff --git a/pkg/adapter/utils/slice.go b/pkg/adapter/utils/slice.go new file mode 100644 index 00000000..24d19ad2 --- /dev/null +++ b/pkg/adapter/utils/slice.go @@ -0,0 +1,101 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +type reducetype func(interface{}) interface{} +type filtertype func(interface{}) bool + +// InSlice checks given string in string slice or not. +func InSlice(v string, sl []string) bool { + return utils.InSlice(v, sl) +} + +// InSliceIface checks given interface in interface slice. +func InSliceIface(v interface{}, sl []interface{}) bool { + return utils.InSliceIface(v, sl) +} + +// SliceRandList generate an int slice from min to max. +func SliceRandList(min, max int) []int { + return utils.SliceRandList(min, max) +} + +// SliceMerge merges interface slices to one slice. +func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) { + return utils.SliceMerge(slice1, slice2) +} + +// SliceReduce generates a new slice after parsing every value by reduce function +func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) { + return utils.SliceReduce(slice, func(i interface{}) interface{} { + return a(i) + }) +} + +// SliceRand returns random one from slice. +func SliceRand(a []interface{}) (b interface{}) { + return utils.SliceRand(a) +} + +// SliceSum sums all values in int64 slice. +func SliceSum(intslice []int64) (sum int64) { + return utils.SliceSum(intslice) +} + +// SliceFilter generates a new slice after filter function. +func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) { + return utils.SliceFilter(slice, func(i interface{}) bool { + return a(i) + }) +} + +// SliceDiff returns diff slice of slice1 - slice2. +func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) { + return utils.SliceDiff(slice1, slice2) +} + +// SliceIntersect returns slice that are present in all the slice1 and slice2. +func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) { + return utils.SliceIntersect(slice1, slice2) +} + +// SliceChunk separates one slice to some sized slice. +func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) { + return utils.SliceChunk(slice, size) +} + +// SliceRange generates a new slice from begin to end with step duration of int64 number. +func SliceRange(start, end, step int64) (intslice []int64) { + return utils.SliceRange(start, end, step) +} + +// SlicePad prepends size number of val into slice. +func SlicePad(slice []interface{}, size int, val interface{}) []interface{} { + return utils.SlicePad(slice, size, val) +} + +// SliceUnique cleans repeated values in slice. +func SliceUnique(slice []interface{}) (uniqueslice []interface{}) { + return utils.SliceUnique(slice) +} + +// SliceShuffle shuffles a slice. +func SliceShuffle(slice []interface{}) []interface{} { + return utils.SliceShuffle(slice) +} diff --git a/pkg/adapter/utils/slice_test.go b/pkg/adapter/utils/slice_test.go new file mode 100644 index 00000000..142dec96 --- /dev/null +++ b/pkg/adapter/utils/slice_test.go @@ -0,0 +1,29 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +func TestInSlice(t *testing.T) { + sl := []string{"A", "b"} + if !InSlice("A", sl) { + t.Error("should be true") + } + if InSlice("B", sl) { + t.Error("should be false") + } +} diff --git a/pkg/adapter/utils/utils.go b/pkg/adapter/utils/utils.go new file mode 100644 index 00000000..1f3bcd31 --- /dev/null +++ b/pkg/adapter/utils/utils.go @@ -0,0 +1,10 @@ +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// GetGOPATHs returns all paths in GOPATH variable. +func GetGOPATHs() []string { + return utils.GetGOPATHs() +} diff --git a/pkg/adapter/validation/util.go b/pkg/adapter/validation/util.go new file mode 100644 index 00000000..729712e0 --- /dev/null +++ b/pkg/adapter/validation/util.go @@ -0,0 +1,62 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "reflect" + + "github.com/astaxie/beego/pkg/infrastructure/validation" +) + +const ( + // ValidTag struct tag + ValidTag = validation.ValidTag + + LabelTag = validation.LabelTag +) + +var ( + ErrInt64On32 = validation.ErrInt64On32 +) + +// CustomFunc is for custom validate function +type CustomFunc func(v *Validation, obj interface{}, key string) + +// AddCustomFunc Add a custom function to validation +// The name can not be: +// Clear +// HasErrors +// ErrorMap +// Error +// Check +// Valid +// NoMatch +// If the name is same with exists function, it will replace the origin valid function +func AddCustomFunc(name string, f CustomFunc) error { + return validation.AddCustomFunc(name, func(v *validation.Validation, obj interface{}, key string) { + f((*Validation)(v), obj, key) + }) +} + +// ValidFunc Valid function type +type ValidFunc validation.ValidFunc + +// Funcs Validate function map +type Funcs validation.Funcs + +// Call validate values with named type string +func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) { + return (validation.Funcs(f)).Call(name, params...) +} diff --git a/pkg/adapter/validation/validation.go b/pkg/adapter/validation/validation.go new file mode 100644 index 00000000..1cdb8dda --- /dev/null +++ b/pkg/adapter/validation/validation.go @@ -0,0 +1,274 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package validation for validations +// +// import ( +// "github.com/astaxie/beego/validation" +// "log" +// ) +// +// type User struct { +// Name string +// Age int +// } +// +// func main() { +// u := User{"man", 40} +// valid := validation.Validation{} +// valid.Required(u.Name, "name") +// valid.MaxSize(u.Name, 15, "nameMax") +// valid.Range(u.Age, 0, 140, "age") +// if valid.HasErrors() { +// // validation does not pass +// // print invalid message +// for _, err := range valid.Errors { +// log.Println(err.Key, err.Message) +// } +// } +// // or use like this +// if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { +// log.Println(v.Error.Key, v.Error.Message) +// } +// } +// +// more info: http://beego.me/docs/mvc/controller/validation.md +package validation + +import ( + "fmt" + "regexp" + + "github.com/astaxie/beego/pkg/infrastructure/validation" +) + +// ValidFormer valid interface +type ValidFormer interface { + Valid(*Validation) +} + +// Error show the error +type Error validation.Error + +// String Returns the Message. +func (e *Error) String() string { + if e == nil { + return "" + } + return e.Message +} + +// Implement Error interface. +// Return e.String() +func (e *Error) Error() string { return e.String() } + +// Result is returned from every validation method. +// It provides an indication of success, and a pointer to the Error (if any). +type Result validation.Result + +// Key Get Result by given key string. +func (r *Result) Key(key string) *Result { + if r.Error != nil { + r.Error.Key = key + } + return r +} + +// Message Set Result message by string or format string with args +func (r *Result) Message(message string, args ...interface{}) *Result { + if r.Error != nil { + if len(args) == 0 { + r.Error.Message = message + } else { + r.Error.Message = fmt.Sprintf(message, args...) + } + } + return r +} + +// A Validation context manages data validation and error messages. +type Validation validation.Validation + +// Clear Clean all ValidationError. +func (v *Validation) Clear() { + (*validation.Validation)(v).Clear() +} + +// HasErrors Has ValidationError nor not. +func (v *Validation) HasErrors() bool { + return (*validation.Validation)(v).HasErrors() +} + +// ErrorMap Return the errors mapped by key. +// If there are multiple validation errors associated with a single key, the +// first one "wins". (Typically the first validation will be the more basic). +func (v *Validation) ErrorMap() map[string][]*Error { + newErrors := (*validation.Validation)(v).ErrorMap() + res := make(map[string][]*Error, len(newErrors)) + for n, es := range newErrors { + errs := make([]*Error, 0, len(es)) + + for _, e := range es { + errs = append(errs, (*Error)(e)) + } + + res[n] = errs + } + return res +} + +// Error Add an error to the validation context. +func (v *Validation) Error(message string, args ...interface{}) *Result { + return (*Result)((*validation.Validation)(v).Error(message, args...)) +} + +// Required Test that the argument is non-nil and non-empty (if string or list) +func (v *Validation) Required(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Required(obj, key)) +} + +// Min Test that the obj is greater than min if obj's type is int +func (v *Validation) Min(obj interface{}, min int, key string) *Result { + return (*Result)((*validation.Validation)(v).Min(obj, min, key)) +} + +// Max Test that the obj is less than max if obj's type is int +func (v *Validation) Max(obj interface{}, max int, key string) *Result { + return (*Result)((*validation.Validation)(v).Max(obj, max, key)) +} + +// Range Test that the obj is between mni and max if obj's type is int +func (v *Validation) Range(obj interface{}, min, max int, key string) *Result { + return (*Result)((*validation.Validation)(v).Range(obj, min, max, key)) +} + +// MinSize Test that the obj is longer than min size if type is string or slice +func (v *Validation) MinSize(obj interface{}, min int, key string) *Result { + return (*Result)((*validation.Validation)(v).MinSize(obj, min, key)) +} + +// MaxSize Test that the obj is shorter than max size if type is string or slice +func (v *Validation) MaxSize(obj interface{}, max int, key string) *Result { + return (*Result)((*validation.Validation)(v).MaxSize(obj, max, key)) +} + +// Length Test that the obj is same length to n if type is string or slice +func (v *Validation) Length(obj interface{}, n int, key string) *Result { + return (*Result)((*validation.Validation)(v).Length(obj, n, key)) +} + +// Alpha Test that the obj is [a-zA-Z] if type is string +func (v *Validation) Alpha(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Alpha(obj, key)) +} + +// Numeric Test that the obj is [0-9] if type is string +func (v *Validation) Numeric(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Numeric(obj, key)) +} + +// AlphaNumeric Test that the obj is [0-9a-zA-Z] if type is string +func (v *Validation) AlphaNumeric(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).AlphaNumeric(obj, key)) +} + +// Match Test that the obj matches regexp if type is string +func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *Result { + return (*Result)((*validation.Validation)(v).Match(obj, regex, key)) +} + +// NoMatch Test that the obj doesn't match regexp if type is string +func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *Result { + return (*Result)((*validation.Validation)(v).NoMatch(obj, regex, key)) +} + +// AlphaDash Test that the obj is [0-9a-zA-Z_-] if type is string +func (v *Validation) AlphaDash(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).AlphaDash(obj, key)) +} + +// Email Test that the obj is email address if type is string +func (v *Validation) Email(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Email(obj, key)) +} + +// IP Test that the obj is IP address if type is string +func (v *Validation) IP(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).IP(obj, key)) +} + +// Base64 Test that the obj is base64 encoded if type is string +func (v *Validation) Base64(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Base64(obj, key)) +} + +// Mobile Test that the obj is chinese mobile number if type is string +func (v *Validation) Mobile(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Mobile(obj, key)) +} + +// Tel Test that the obj is chinese telephone number if type is string +func (v *Validation) Tel(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Tel(obj, key)) +} + +// Phone Test that the obj is chinese mobile or telephone number if type is string +func (v *Validation) Phone(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Phone(obj, key)) +} + +// ZipCode Test that the obj is chinese zip code if type is string +func (v *Validation) ZipCode(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).ZipCode(obj, key)) +} + +// key must like aa.bb.cc or aa.bb. +// AddError adds independent error message for the provided key +func (v *Validation) AddError(key, message string) { + (*validation.Validation)(v).AddError(key, message) +} + +// SetError Set error message for one field in ValidationError +func (v *Validation) SetError(fieldName string, errMsg string) *Error { + return (*Error)((*validation.Validation)(v).SetError(fieldName, errMsg)) +} + +// Check Apply a group of validators to a field, in order, and return the +// ValidationResult from the first one that fails, or the last one that +// succeeds. +func (v *Validation) Check(obj interface{}, checks ...Validator) *Result { + vldts := make([]validation.Validator, 0, len(checks)) + for _, v := range checks { + vldts = append(vldts, validation.Validator(v)) + } + return (*Result)((*validation.Validation)(v).Check(obj, vldts...)) +} + +// Valid Validate a struct. +// the obj parameter must be a struct or a struct pointer +func (v *Validation) Valid(obj interface{}) (b bool, err error) { + return (*validation.Validation)(v).Valid(obj) +} + +// RecursiveValid Recursively validate a struct. +// Step1: Validate by v.Valid +// Step2: If pass on step1, then reflect obj's fields +// Step3: Do the Recursively validation to all struct or struct pointer fields +func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { + return (*validation.Validation)(v).RecursiveValid(objc) +} + +func (v *Validation) CanSkipAlso(skipFunc string) { + (*validation.Validation)(v).CanSkipAlso(skipFunc) +} diff --git a/pkg/adapter/validation/validation_test.go b/pkg/adapter/validation/validation_test.go new file mode 100644 index 00000000..b4b5b1b6 --- /dev/null +++ b/pkg/adapter/validation/validation_test.go @@ -0,0 +1,609 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "regexp" + "testing" + "time" +) + +func TestRequired(t *testing.T) { + valid := Validation{} + + if valid.Required(nil, "nil").Ok { + t.Error("nil object should be false") + } + if !valid.Required(true, "bool").Ok { + t.Error("Bool value should always return true") + } + if !valid.Required(false, "bool").Ok { + t.Error("Bool value should always return true") + } + if valid.Required("", "string").Ok { + t.Error("\"'\" string should be false") + } + if valid.Required(" ", "string").Ok { + t.Error("\" \" string should be false") // For #2361 + } + if valid.Required("\n", "string").Ok { + t.Error("new line string should be false") // For #2361 + } + if !valid.Required("astaxie", "string").Ok { + t.Error("string should be true") + } + if valid.Required(0, "zero").Ok { + t.Error("Integer should not be equal 0") + } + if !valid.Required(1, "int").Ok { + t.Error("Integer except 0 should be true") + } + if !valid.Required(time.Now(), "time").Ok { + t.Error("time should be true") + } + if valid.Required([]string{}, "emptySlice").Ok { + t.Error("empty slice should be false") + } + if !valid.Required([]interface{}{"ok"}, "slice").Ok { + t.Error("slice should be true") + } +} + +func TestMin(t *testing.T) { + valid := Validation{} + + if valid.Min(-1, 0, "min0").Ok { + t.Error("-1 is less than the minimum value of 0 should be false") + } + if !valid.Min(1, 0, "min0").Ok { + t.Error("1 is greater or equal than the minimum value of 0 should be true") + } +} + +func TestMax(t *testing.T) { + valid := Validation{} + + if valid.Max(1, 0, "max0").Ok { + t.Error("1 is greater than the minimum value of 0 should be false") + } + if !valid.Max(-1, 0, "max0").Ok { + t.Error("-1 is less or equal than the maximum value of 0 should be true") + } +} + +func TestRange(t *testing.T) { + valid := Validation{} + + if valid.Range(-1, 0, 1, "range0_1").Ok { + t.Error("-1 is between 0 and 1 should be false") + } + if !valid.Range(1, 0, 1, "range0_1").Ok { + t.Error("1 is between 0 and 1 should be true") + } +} + +func TestMinSize(t *testing.T) { + valid := Validation{} + + if valid.MinSize("", 1, "minSize1").Ok { + t.Error("the length of \"\" is less than the minimum value of 1 should be false") + } + if !valid.MinSize("ok", 1, "minSize1").Ok { + t.Error("the length of \"ok\" is greater or equal than the minimum value of 1 should be true") + } + if valid.MinSize([]string{}, 1, "minSize1").Ok { + t.Error("the length of empty slice is less than the minimum value of 1 should be false") + } + if !valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok { + t.Error("the length of [\"ok\"] is greater or equal than the minimum value of 1 should be true") + } +} + +func TestMaxSize(t *testing.T) { + valid := Validation{} + + if valid.MaxSize("ok", 1, "maxSize1").Ok { + t.Error("the length of \"ok\" is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize("", 1, "maxSize1").Ok { + t.Error("the length of \"\" is less or equal than the maximum value of 1 should be true") + } + if valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok { + t.Error("the length of [\"ok\", false] is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize([]string{}, 1, "maxSize1").Ok { + t.Error("the length of empty slice is less or equal than the maximum value of 1 should be true") + } +} + +func TestLength(t *testing.T) { + valid := Validation{} + + if valid.Length("", 1, "length1").Ok { + t.Error("the length of \"\" must equal 1 should be false") + } + if !valid.Length("1", 1, "length1").Ok { + t.Error("the length of \"1\" must equal 1 should be true") + } + if valid.Length([]string{}, 1, "length1").Ok { + t.Error("the length of empty slice must equal 1 should be false") + } + if !valid.Length([]interface{}{"ok"}, 1, "length1").Ok { + t.Error("the length of [\"ok\"] must equal 1 should be true") + } +} + +func TestAlpha(t *testing.T) { + valid := Validation{} + + if valid.Alpha("a,1-@ $", "alpha").Ok { + t.Error("\"a,1-@ $\" are valid alpha characters should be false") + } + if !valid.Alpha("abCD", "alpha").Ok { + t.Error("\"abCD\" are valid alpha characters should be true") + } +} + +func TestNumeric(t *testing.T) { + valid := Validation{} + + if valid.Numeric("a,1-@ $", "numeric").Ok { + t.Error("\"a,1-@ $\" are valid numeric characters should be false") + } + if !valid.Numeric("1234", "numeric").Ok { + t.Error("\"1234\" are valid numeric characters should be true") + } +} + +func TestAlphaNumeric(t *testing.T) { + valid := Validation{} + + if valid.AlphaNumeric("a,1-@ $", "alphaNumeric").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric characters should be false") + } + if !valid.AlphaNumeric("1234aB", "alphaNumeric").Ok { + t.Error("\"1234aB\" are valid alpha or numeric characters should be true") + } +} + +func TestMatch(t *testing.T) { + valid := Validation{} + + if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") + } + if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") + } +} + +func TestNoMatch(t *testing.T) { + valid := Validation{} + + if valid.NoMatch("123@gmail", regexp.MustCompile(`[^\w\d]`), "nomatch").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false") + } + if !valid.NoMatch("123gmail", regexp.MustCompile(`[^\w\d]`), "match").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true") + } +} + +func TestAlphaDash(t *testing.T) { + valid := Validation{} + + if valid.AlphaDash("a,1-@ $", "alphaDash").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric or dash(-_) characters should be false") + } + if !valid.AlphaDash("1234aB-_", "alphaDash").Ok { + t.Error("\"1234aB\" are valid alpha or numeric or dash(-_) characters should be true") + } +} + +func TestEmail(t *testing.T) { + valid := Validation{} + + if valid.Email("not@a email", "email").Ok { + t.Error("\"not@a email\" is a valid email address should be false") + } + if !valid.Email("suchuangji@gmail.com", "email").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid email address should be true") + } + if valid.Email("@suchuangji@gmail.com", "email").Ok { + t.Error("\"@suchuangji@gmail.com\" is a valid email address should be false") + } + if valid.Email("suchuangji@gmail.com ok", "email").Ok { + t.Error("\"suchuangji@gmail.com ok\" is a valid email address should be false") + } +} + +func TestIP(t *testing.T) { + valid := Validation{} + + if valid.IP("11.255.255.256", "IP").Ok { + t.Error("\"11.255.255.256\" is a valid ip address should be false") + } + if !valid.IP("01.11.11.11", "IP").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid ip address should be true") + } +} + +func TestBase64(t *testing.T) { + valid := Validation{} + + if valid.Base64("suchuangji@gmail.com", "base64").Ok { + t.Error("\"suchuangji@gmail.com\" are a valid base64 characters should be false") + } + if !valid.Base64("c3VjaHVhbmdqaUBnbWFpbC5jb20=", "base64").Ok { + t.Error("\"c3VjaHVhbmdqaUBnbWFpbC5jb20=\" are a valid base64 characters should be true") + } +} + +func TestMobile(t *testing.T) { + valid := Validation{} + + validMobiles := []string{ + "19800008888", + "18800008888", + "18000008888", + "8618300008888", + "+8614700008888", + "17300008888", + "+8617100008888", + "8617500008888", + "8617400008888", + "16200008888", + "16500008888", + "16600008888", + "16700008888", + "13300008888", + "14900008888", + "15300008888", + "17300008888", + "17700008888", + "18000008888", + "18900008888", + "19100008888", + "19900008888", + "19300008888", + "13000008888", + "13100008888", + "13200008888", + "14500008888", + "15500008888", + "15600008888", + "16600008888", + "17100008888", + "17500008888", + "17600008888", + "18500008888", + "18600008888", + "13400008888", + "13500008888", + "13600008888", + "13700008888", + "13800008888", + "13900008888", + "14700008888", + "15000008888", + "15100008888", + "15200008888", + "15800008888", + "15900008888", + "17200008888", + "17800008888", + "18200008888", + "18300008888", + "18400008888", + "18700008888", + "18800008888", + "19800008888", + } + + for _, m := range validMobiles { + if !valid.Mobile(m, "mobile").Ok { + t.Error(m + " is a valid mobile phone number should be true") + } + } +} + +func TestTel(t *testing.T) { + valid := Validation{} + + if valid.Tel("222-00008888", "telephone").Ok { + t.Error("\"222-00008888\" is a valid telephone number should be false") + } + if !valid.Tel("022-70008888", "telephone").Ok { + t.Error("\"022-70008888\" is a valid telephone number should be true") + } + if !valid.Tel("02270008888", "telephone").Ok { + t.Error("\"02270008888\" is a valid telephone number should be true") + } + if !valid.Tel("70008888", "telephone").Ok { + t.Error("\"70008888\" is a valid telephone number should be true") + } +} + +func TestPhone(t *testing.T) { + valid := Validation{} + + if valid.Phone("222-00008888", "phone").Ok { + t.Error("\"222-00008888\" is a valid phone number should be false") + } + if !valid.Mobile("+8614700008888", "phone").Ok { + t.Error("\"+8614700008888\" is a valid phone number should be true") + } + if !valid.Tel("02270008888", "phone").Ok { + t.Error("\"02270008888\" is a valid phone number should be true") + } +} + +func TestZipCode(t *testing.T) { + valid := Validation{} + + if valid.ZipCode("", "zipcode").Ok { + t.Error("\"00008888\" is a valid zipcode should be false") + } + if !valid.ZipCode("536000", "zipcode").Ok { + t.Error("\"536000\" is a valid zipcode should be true") + } +} + +func TestValid(t *testing.T) { + type user struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + valid := Validation{} + + u := user{Name: "test@/test/;com", Age: 40} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Error("validation should be passed") + } + + uptr := &user{Name: "test", Age: 40} + valid.Clear() + b, err = valid.Valid(uptr) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Name.Match" { + t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key) + } + + u = user{Name: "test@/test/;com", Age: 180} + valid.Clear() + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Age.Range." { + t.Errorf("Message key should be `Age.Range` but got %s", valid.Errors[0].Key) + } +} + +func TestRecursiveValid(t *testing.T) { + type User struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + + type AnonymouseUser struct { + ID2 int + Name2 string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age2 int `valid:"Required;Range(1, 140)"` + } + + type Account struct { + Password string `valid:"Required"` + U User + AnonymouseUser + } + valid := Validation{} + + u := Account{Password: "abc123_", U: User{}} + b, err := valid.RecursiveValid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } +} + +func TestSkipValid(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + + IP string `valid:"IP"` + ReqIP string `valid:"Required;IP"` + + Mobile string `valid:"Mobile"` + ReqMobile string `valid:"Required;Mobile"` + + Tel string `valid:"Tel"` + ReqTel string `valid:"Required;Tel"` + + Phone string `valid:"Phone"` + ReqPhone string `valid:"Required;Phone"` + + ZipCode string `valid:"ZipCode"` + ReqZipCode string `valid:"Required;ZipCode"` + } + + u := User{ + ReqEmail: "a@a.com", + ReqIP: "127.0.0.1", + ReqMobile: "18888888888", + ReqTel: "02088888888", + ReqPhone: "02088888888", + ReqZipCode: "510000", + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } +} + +func TestPointer(t *testing.T) { + type User struct { + ID int + + Email *string `valid:"Email"` + ReqEmail *string `valid:"Required;Email"` + } + + u := User{ + ReqEmail: nil, + Email: nil, + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + validEmail := "a@a.com" + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + invalidEmail := "a@a" + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } +} + +func TestCanSkipAlso(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + MatchRange int `valid:"Range(10, 20)"` + } + + u := User{ + ReqEmail: "a@a.com", + Email: "", + MatchRange: 0, + } + + valid := Validation{RequiredFirst: true} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + valid.CanSkipAlso("Range") + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + +} diff --git a/pkg/adapter/validation/validators.go b/pkg/adapter/validation/validators.go new file mode 100644 index 00000000..1a063749 --- /dev/null +++ b/pkg/adapter/validation/validators.go @@ -0,0 +1,512 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "sync" + + "github.com/astaxie/beego/pkg/infrastructure/validation" +) + +// CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty +var CanSkipFuncs = validation.CanSkipFuncs + +// MessageTmpls store commond validate template +var MessageTmpls = map[string]string{ + "Required": "Can not be empty", + "Min": "Minimum is %d", + "Max": "Maximum is %d", + "Range": "Range is %d to %d", + "MinSize": "Minimum size is %d", + "MaxSize": "Maximum size is %d", + "Length": "Required length is %d", + "Alpha": "Must be valid alpha characters", + "Numeric": "Must be valid numeric characters", + "AlphaNumeric": "Must be valid alpha or numeric characters", + "Match": "Must match %s", + "NoMatch": "Must not match %s", + "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", + "Email": "Must be a valid email address", + "IP": "Must be a valid ip address", + "Base64": "Must be valid base64 characters", + "Mobile": "Must be valid mobile number", + "Tel": "Must be valid telephone number", + "Phone": "Must be valid telephone or mobile phone number", + "ZipCode": "Must be valid zipcode", +} + +var once sync.Once + +// SetDefaultMessage set default messages +// if not set, the default messages are +// "Required": "Can not be empty", +// "Min": "Minimum is %d", +// "Max": "Maximum is %d", +// "Range": "Range is %d to %d", +// "MinSize": "Minimum size is %d", +// "MaxSize": "Maximum size is %d", +// "Length": "Required length is %d", +// "Alpha": "Must be valid alpha characters", +// "Numeric": "Must be valid numeric characters", +// "AlphaNumeric": "Must be valid alpha or numeric characters", +// "Match": "Must match %s", +// "NoMatch": "Must not match %s", +// "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", +// "Email": "Must be a valid email address", +// "IP": "Must be a valid ip address", +// "Base64": "Must be valid base64 characters", +// "Mobile": "Must be valid mobile number", +// "Tel": "Must be valid telephone number", +// "Phone": "Must be valid telephone or mobile phone number", +// "ZipCode": "Must be valid zipcode", +func SetDefaultMessage(msg map[string]string) { + validation.SetDefaultMessage(msg) +} + +// Validator interface +type Validator interface { + IsSatisfied(interface{}) bool + DefaultMessage() string + GetKey() string + GetLimitValue() interface{} +} + +// Required struct +type Required validation.Required + +// IsSatisfied judge whether obj has value +func (r Required) IsSatisfied(obj interface{}) bool { + return validation.Required(r).IsSatisfied(obj) +} + +// DefaultMessage return the default error message +func (r Required) DefaultMessage() string { + return validation.Required(r).DefaultMessage() +} + +// GetKey return the r.Key +func (r Required) GetKey() string { + return validation.Required(r).GetKey() +} + +// GetLimitValue return nil now +func (r Required) GetLimitValue() interface{} { + return validation.Required(r).GetLimitValue() +} + +// Min check struct +type Min validation.Min + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Min) IsSatisfied(obj interface{}) bool { + return validation.Min(m).IsSatisfied(obj) +} + +// DefaultMessage return the default min error message +func (m Min) DefaultMessage() string { + return validation.Min(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Min) GetKey() string { + return validation.Min(m).GetKey() +} + +// GetLimitValue return the limit value, Min +func (m Min) GetLimitValue() interface{} { + return validation.Min(m).GetLimitValue() +} + +// Max validate struct +type Max validation.Max + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Max) IsSatisfied(obj interface{}) bool { + return validation.Max(m).IsSatisfied(obj) +} + +// DefaultMessage return the default max error message +func (m Max) DefaultMessage() string { + return validation.Max(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Max) GetKey() string { + return validation.Max(m).GetKey() +} + +// GetLimitValue return the limit value, Max +func (m Max) GetLimitValue() interface{} { + return validation.Max(m).GetLimitValue() +} + +// Range Requires an integer to be within Min, Max inclusive. +type Range validation.Range + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (r Range) IsSatisfied(obj interface{}) bool { + return validation.Range(r).IsSatisfied(obj) +} + +// DefaultMessage return the default Range error message +func (r Range) DefaultMessage() string { + return validation.Range(r).DefaultMessage() +} + +// GetKey return the m.Key +func (r Range) GetKey() string { + return validation.Range(r).GetKey() +} + +// GetLimitValue return the limit value, Max +func (r Range) GetLimitValue() interface{} { + return validation.Range(r).GetLimitValue() +} + +// MinSize Requires an array or string to be at least a given length. +type MinSize validation.MinSize + +// IsSatisfied judge whether obj is valid +func (m MinSize) IsSatisfied(obj interface{}) bool { + return validation.MinSize(m).IsSatisfied(obj) +} + +// DefaultMessage return the default MinSize error message +func (m MinSize) DefaultMessage() string { + return validation.MinSize(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m MinSize) GetKey() string { + return validation.MinSize(m).GetKey() +} + +// GetLimitValue return the limit value +func (m MinSize) GetLimitValue() interface{} { + return validation.MinSize(m).GetLimitValue() +} + +// MaxSize Requires an array or string to be at most a given length. +type MaxSize validation.MaxSize + +// IsSatisfied judge whether obj is valid +func (m MaxSize) IsSatisfied(obj interface{}) bool { + return validation.MaxSize(m).IsSatisfied(obj) +} + +// DefaultMessage return the default MaxSize error message +func (m MaxSize) DefaultMessage() string { + return validation.MaxSize(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m MaxSize) GetKey() string { + return validation.MaxSize(m).GetKey() +} + +// GetLimitValue return the limit value +func (m MaxSize) GetLimitValue() interface{} { + return validation.MaxSize(m).GetLimitValue() +} + +// Length Requires an array or string to be exactly a given length. +type Length validation.Length + +// IsSatisfied judge whether obj is valid +func (l Length) IsSatisfied(obj interface{}) bool { + return validation.Length(l).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (l Length) DefaultMessage() string { + return validation.Length(l).DefaultMessage() +} + +// GetKey return the m.Key +func (l Length) GetKey() string { + return validation.Length(l).GetKey() +} + +// GetLimitValue return the limit value +func (l Length) GetLimitValue() interface{} { + return validation.Length(l).GetLimitValue() +} + +// Alpha check the alpha +type Alpha validation.Alpha + +// IsSatisfied judge whether obj is valid +func (a Alpha) IsSatisfied(obj interface{}) bool { + return validation.Alpha(a).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (a Alpha) DefaultMessage() string { + return validation.Alpha(a).DefaultMessage() +} + +// GetKey return the m.Key +func (a Alpha) GetKey() string { + return validation.Alpha(a).GetKey() +} + +// GetLimitValue return the limit value +func (a Alpha) GetLimitValue() interface{} { + return validation.Alpha(a).GetLimitValue() +} + +// Numeric check number +type Numeric validation.Numeric + +// IsSatisfied judge whether obj is valid +func (n Numeric) IsSatisfied(obj interface{}) bool { + return validation.Numeric(n).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (n Numeric) DefaultMessage() string { + return validation.Numeric(n).DefaultMessage() +} + +// GetKey return the n.Key +func (n Numeric) GetKey() string { + return validation.Numeric(n).GetKey() +} + +// GetLimitValue return the limit value +func (n Numeric) GetLimitValue() interface{} { + return validation.Numeric(n).GetLimitValue() +} + +// AlphaNumeric check alpha and number +type AlphaNumeric validation.AlphaNumeric + +// IsSatisfied judge whether obj is valid +func (a AlphaNumeric) IsSatisfied(obj interface{}) bool { + return validation.AlphaNumeric(a).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (a AlphaNumeric) DefaultMessage() string { + return validation.AlphaNumeric(a).DefaultMessage() +} + +// GetKey return the a.Key +func (a AlphaNumeric) GetKey() string { + return validation.AlphaNumeric(a).GetKey() +} + +// GetLimitValue return the limit value +func (a AlphaNumeric) GetLimitValue() interface{} { + return validation.AlphaNumeric(a).GetLimitValue() +} + +// Match Requires a string to match a given regex. +type Match validation.Match + +// IsSatisfied judge whether obj is valid +func (m Match) IsSatisfied(obj interface{}) bool { + return validation.Match(m).IsSatisfied(obj) +} + +// DefaultMessage return the default Match error message +func (m Match) DefaultMessage() string { + return validation.Match(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Match) GetKey() string { + return validation.Match(m).GetKey() +} + +// GetLimitValue return the limit value +func (m Match) GetLimitValue() interface{} { + return validation.Match(m).GetLimitValue() +} + +// NoMatch Requires a string to not match a given regex. +type NoMatch validation.NoMatch + +// IsSatisfied judge whether obj is valid +func (n NoMatch) IsSatisfied(obj interface{}) bool { + return validation.NoMatch(n).IsSatisfied(obj) +} + +// DefaultMessage return the default NoMatch error message +func (n NoMatch) DefaultMessage() string { + return validation.NoMatch(n).DefaultMessage() +} + +// GetKey return the n.Key +func (n NoMatch) GetKey() string { + return validation.NoMatch(n).GetKey() +} + +// GetLimitValue return the limit value +func (n NoMatch) GetLimitValue() interface{} { + return validation.NoMatch(n).GetLimitValue() +} + +// AlphaDash check not Alpha +type AlphaDash validation.AlphaDash + +// DefaultMessage return the default AlphaDash error message +func (a AlphaDash) DefaultMessage() string { + return validation.AlphaDash(a).DefaultMessage() +} + +// GetKey return the n.Key +func (a AlphaDash) GetKey() string { + return validation.AlphaDash(a).GetKey() +} + +// GetLimitValue return the limit value +func (a AlphaDash) GetLimitValue() interface{} { + return validation.AlphaDash(a).GetLimitValue() +} + +// Email check struct +type Email validation.Email + +// DefaultMessage return the default Email error message +func (e Email) DefaultMessage() string { + return validation.Email(e).DefaultMessage() +} + +// GetKey return the n.Key +func (e Email) GetKey() string { + return validation.Email(e).GetKey() +} + +// GetLimitValue return the limit value +func (e Email) GetLimitValue() interface{} { + return validation.Email(e).GetLimitValue() +} + +// IP check struct +type IP validation.IP + +// DefaultMessage return the default IP error message +func (i IP) DefaultMessage() string { + return validation.IP(i).DefaultMessage() +} + +// GetKey return the i.Key +func (i IP) GetKey() string { + return validation.IP(i).GetKey() +} + +// GetLimitValue return the limit value +func (i IP) GetLimitValue() interface{} { + return validation.IP(i).GetLimitValue() +} + +// Base64 check struct +type Base64 validation.Base64 + +// DefaultMessage return the default Base64 error message +func (b Base64) DefaultMessage() string { + return validation.Base64(b).DefaultMessage() +} + +// GetKey return the b.Key +func (b Base64) GetKey() string { + return validation.Base64(b).GetKey() +} + +// GetLimitValue return the limit value +func (b Base64) GetLimitValue() interface{} { + return validation.Base64(b).GetLimitValue() +} + +// Mobile check struct +type Mobile validation.Mobile + +// DefaultMessage return the default Mobile error message +func (m Mobile) DefaultMessage() string { + return validation.Mobile(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Mobile) GetKey() string { + return validation.Mobile(m).GetKey() +} + +// GetLimitValue return the limit value +func (m Mobile) GetLimitValue() interface{} { + return validation.Mobile(m).GetLimitValue() +} + +// Tel check telephone struct +type Tel validation.Tel + +// DefaultMessage return the default Tel error message +func (t Tel) DefaultMessage() string { + return validation.Tel(t).DefaultMessage() +} + +// GetKey return the t.Key +func (t Tel) GetKey() string { + return validation.Tel(t).GetKey() +} + +// GetLimitValue return the limit value +func (t Tel) GetLimitValue() interface{} { + return validation.Tel(t).GetLimitValue() +} + +// Phone just for chinese telephone or mobile phone number +type Phone validation.Phone + +// IsSatisfied judge whether obj is valid +func (p Phone) IsSatisfied(obj interface{}) bool { + return validation.Phone(p).IsSatisfied(obj) +} + +// DefaultMessage return the default Phone error message +func (p Phone) DefaultMessage() string { + return validation.Phone(p).DefaultMessage() +} + +// GetKey return the p.Key +func (p Phone) GetKey() string { + return validation.Phone(p).GetKey() +} + +// GetLimitValue return the limit value +func (p Phone) GetLimitValue() interface{} { + return validation.Phone(p).GetLimitValue() +} + +// ZipCode check the zip struct +type ZipCode validation.ZipCode + +// DefaultMessage return the default Zip error message +func (z ZipCode) DefaultMessage() string { + return validation.ZipCode(z).DefaultMessage() +} + +// GetKey return the z.Key +func (z ZipCode) GetKey() string { + return validation.ZipCode(z).GetKey() +} + +// GetLimitValue return the limit value +func (z ZipCode) GetLimitValue() interface{} { + return validation.ZipCode(z).GetLimitValue() +} diff --git a/pkg/client/orm/cmd.go b/pkg/client/orm/cmd.go index e03fc0ee..b0661971 100644 --- a/pkg/client/orm/cmd.go +++ b/pkg/client/orm/cmd.go @@ -142,6 +142,12 @@ func (d *commandSyncDb) Run() error { } for i, mi := range modelCache.allOrdered() { + + if !isApplicableTableForDB(mi.addrField, d.al.Name) { + fmt.Printf("table `%s` is not applicable to database '%s'\n", mi.table, d.al.Name) + continue + } + if tables[mi.table] { if !d.noInfo { fmt.Printf("table `%s` already exists, skip\n", mi.table) diff --git a/pkg/client/orm/db_alias.go b/pkg/client/orm/db_alias.go index 8a5cfb10..29e0904c 100644 --- a/pkg/client/orm/db_alias.go +++ b/pkg/client/orm/db_alias.go @@ -21,9 +21,6 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/client/orm/hints" - "github.com/astaxie/beego/pkg/infrastructure/utils" - lru "github.com/hashicorp/golang-lru" ) @@ -278,6 +275,7 @@ type alias struct { MaxIdleConns int MaxOpenConns int ConnMaxLifetime time.Duration + StmtCacheSize int DB *DB DbBaser dbBaser TZ *time.Location @@ -340,7 +338,7 @@ func detectTZ(al *alias) { } } -func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...utils.KV) (*alias, error) { +func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) { existErr := fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) if _, ok := dataBaseCache.get(aliasName); ok { return nil, existErr @@ -358,32 +356,35 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...utils.KV) return al, nil } -func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...utils.KV) (*alias, error) { - kvs := utils.NewKVs(params...) +func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) { + + al := &alias{} + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + } + + for _, p := range params { + p(al) + } var stmtCache *lru.Cache var stmtCacheSize int - maxStmtCacheSize := kvs.GetValueOr(hints.KeyMaxStmtCacheSize, 0).(int) - if maxStmtCacheSize > 0 { - _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) + if al.StmtCacheSize > 0 { + _stmtCache, errC := newStmtDecoratorLruWithEvict(al.StmtCacheSize) if errC != nil { return nil, errC } else { stmtCache = _stmtCache - stmtCacheSize = maxStmtCacheSize + stmtCacheSize = al.StmtCacheSize } } - al := new(alias) al.Name = aliasName al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: stmtCache, - stmtDecoratorsLimit: stmtCacheSize, - } + al.DB.stmtDecorators = stmtCache + al.DB.stmtDecoratorsLimit = stmtCacheSize if dr, ok := drivers[driverName]; ok { al.DbBaser = dbBasers[dr] @@ -399,31 +400,48 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...utils.KV detectTZ(al) - kvs.IfContains(hints.KeyMaxIdleConnections, func(value interface{}) { - if m, ok := value.(int); ok { - SetMaxIdleConns(al, m) - } - }).IfContains(hints.KeyMaxOpenConnections, func(value interface{}) { - if m, ok := value.(int); ok { - SetMaxOpenConns(al, m) - } - }).IfContains(hints.KeyConnMaxLifetime, func(value interface{}) { - if m, ok := value.(time.Duration); ok { - SetConnMaxLifetime(al, m) - } - }) - return al, nil } +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +// Deprecated you should not use this, we will remove it in the future +func SetMaxIdleConns(aliasName string, maxIdleConns int) { + al := getDbAlias(aliasName) + al.SetMaxIdleConns(maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +// Deprecated you should not use this, we will remove it in the future +func SetMaxOpenConns(aliasName string, maxOpenConns int) { + al := getDbAlias(aliasName) + al.SetMaxIdleConns(maxOpenConns) +} + +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +func (al *alias) SetMaxIdleConns(maxIdleConns int) { + al.MaxIdleConns = maxIdleConns + al.DB.DB.SetMaxIdleConns(maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +func (al *alias) SetMaxOpenConns(maxOpenConns int) { + al.MaxOpenConns = maxOpenConns + al.DB.DB.SetMaxOpenConns(maxOpenConns) +} + +func (al *alias) SetConnMaxLifetime(lifeTime time.Duration) { + al.ConnMaxLifetime = lifeTime + al.DB.DB.SetConnMaxLifetime(lifeTime) +} + // AddAliasWthDB add a aliasName for the drivename -func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...utils.KV) error { +func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) error { _, err := addAliasWthDB(aliasName, driverName, db, params...) return err } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, params ...utils.KV) error { +func RegisterDataBase(aliasName, driverName, dataSource string, params ...DBOption) error { var ( err error db *sql.DB @@ -476,23 +494,6 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { return nil } -// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name -func SetMaxIdleConns(al *alias, maxIdleConns int) { - al.MaxIdleConns = maxIdleConns - al.DB.DB.SetMaxIdleConns(maxIdleConns) -} - -// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name -func SetMaxOpenConns(al *alias, maxOpenConns int) { - al.MaxOpenConns = maxOpenConns - al.DB.DB.SetMaxOpenConns(maxOpenConns) -} - -func SetConnMaxLifetime(al *alias, lifeTime time.Duration) { - al.ConnMaxLifetime = lifeTime - al.DB.DB.SetConnMaxLifetime(lifeTime) -} - // GetDB Get *sql.DB from registered database by db alias name. // Use "default" as alias name if you not set. func GetDB(aliasNames ...string) (*sql.DB, error) { @@ -553,3 +554,33 @@ func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) { } return cache, nil } + +type DBOption func(al *alias) + +// MaxIdleConnections return a hint about MaxIdleConnections +func MaxIdleConnections(maxIdleConn int) DBOption { + return func(al *alias) { + al.SetMaxIdleConns(maxIdleConn) + } +} + +// MaxOpenConnections return a hint about MaxOpenConnections +func MaxOpenConnections(maxOpenConn int) DBOption { + return func(al *alias) { + al.SetMaxOpenConns(maxOpenConn) + } +} + +// ConnMaxLifetime return a hint about ConnMaxLifetime +func ConnMaxLifetime(v time.Duration) DBOption { + return func(al *alias) { + al.SetConnMaxLifetime(v) + } +} + +// MaxStmtCacheSize return a hint about MaxStmtCacheSize +func MaxStmtCacheSize(v int) DBOption { + return func(al *alias) { + al.StmtCacheSize = v + } +} diff --git a/pkg/client/orm/db_alias_test.go b/pkg/client/orm/db_alias_test.go index 0043ba76..6275cb2a 100644 --- a/pkg/client/orm/db_alias_test.go +++ b/pkg/client/orm/db_alias_test.go @@ -18,16 +18,14 @@ import ( "testing" "time" - "github.com/astaxie/beego/pkg/client/orm/hints" - "github.com/stretchr/testify/assert" ) func TestRegisterDataBase(t *testing.T) { err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, - hints.MaxIdleConnections(20), - hints.MaxOpenConnections(300), - hints.ConnMaxLifetime(time.Minute)) + MaxIdleConnections(20), + MaxOpenConnections(300), + ConnMaxLifetime(time.Minute)) assert.Nil(t, err) al := getDbAlias("test-params") @@ -39,7 +37,7 @@ func TestRegisterDataBase(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(-1)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -49,7 +47,7 @@ func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(0)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -59,7 +57,7 @@ func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(1)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -69,7 +67,7 @@ func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(841)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) assert.Nil(t, err) al := getDbAlias(aliasName) diff --git a/pkg/client/orm/hints/db_hints.go b/pkg/client/orm/hints/db_hints.go index 4d199312..7340bd07 100644 --- a/pkg/client/orm/hints/db_hints.go +++ b/pkg/client/orm/hints/db_hints.go @@ -15,20 +15,12 @@ package hints import ( - "time" - "github.com/astaxie/beego/pkg/infrastructure/utils" ) const ( - //db level - KeyMaxIdleConnections = iota - KeyMaxOpenConnections - KeyConnMaxLifetime - KeyMaxStmtCacheSize - //query level - KeyForceIndex + KeyForceIndex = iota KeyUseIndex KeyIgnoreIndex KeyForUpdate @@ -57,26 +49,6 @@ func (s *Hint) GetValue() interface{} { var _ utils.KV = new(Hint) -// MaxIdleConnections return a hint about MaxIdleConnections -func MaxIdleConnections(v int) *Hint { - return NewHint(KeyMaxIdleConnections, v) -} - -// MaxOpenConnections return a hint about MaxOpenConnections -func MaxOpenConnections(v int) *Hint { - return NewHint(KeyMaxOpenConnections, v) -} - -// ConnMaxLifetime return a hint about ConnMaxLifetime -func ConnMaxLifetime(v time.Duration) *Hint { - return NewHint(KeyConnMaxLifetime, v) -} - -// MaxStmtCacheSize return a hint about MaxStmtCacheSize -func MaxStmtCacheSize(v int) *Hint { - return NewHint(KeyMaxStmtCacheSize, v) -} - // ForceIndex return a hint about ForceIndex func ForceIndex(indexes ...string) *Hint { return NewHint(KeyForceIndex, indexes) diff --git a/pkg/client/orm/hints/db_hints_test.go b/pkg/client/orm/hints/db_hints_test.go index 4e962a8f..510f9f16 100644 --- a/pkg/client/orm/hints/db_hints_test.go +++ b/pkg/client/orm/hints/db_hints_test.go @@ -48,34 +48,6 @@ func TestNewHint_float(t *testing.T) { assert.Equal(t, hint.GetValue(), value) } -func TestMaxOpenConnections(t *testing.T) { - i := 887423 - hint := MaxOpenConnections(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyMaxOpenConnections) -} - -func TestConnMaxLifetime(t *testing.T) { - i := time.Hour - hint := ConnMaxLifetime(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyConnMaxLifetime) -} - -func TestMaxIdleConnections(t *testing.T) { - i := 42316 - hint := MaxIdleConnections(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyMaxIdleConnections) -} - -func TestMaxStmtCacheSize(t *testing.T) { - i := 94157 - hint := MaxStmtCacheSize(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyMaxStmtCacheSize) -} - func TestForceIndex(t *testing.T) { s := []string{`f_index1`, `f_index2`, `f_index3`} hint := ForceIndex(s...) diff --git a/pkg/client/orm/migration/ddl.go b/pkg/client/orm/migration/ddl.go index c21352a8..e8b13212 100644 --- a/pkg/client/orm/migration/ddl.go +++ b/pkg/client/orm/migration/ddl.go @@ -31,7 +31,7 @@ type Unique struct { Columns []*Column } -//Column struct defines a single column of a table +// Column struct defines a single column of a table type Column struct { Name string Inc string @@ -84,7 +84,7 @@ func (m *Migration) NewCol(name string) *Column { return col } -//PriCol creates a new primary column and attaches it to m struct +// PriCol creates a new primary column and attaches it to m struct func (m *Migration) PriCol(name string) *Column { col := &Column{Name: name} m.AddColumns(col) @@ -92,7 +92,7 @@ func (m *Migration) PriCol(name string) *Column { return col } -//UniCol creates / appends columns to specified unique key and attaches it to m struct +// UniCol creates / appends columns to specified unique key and attaches it to m struct func (m *Migration) UniCol(uni, name string) *Column { col := &Column{Name: name} m.AddColumns(col) @@ -114,7 +114,7 @@ func (m *Migration) UniCol(uni, name string) *Column { return col } -//ForeignCol creates a new foreign column and returns the instance of column +// ForeignCol creates a new foreign column and returns the instance of column func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable} @@ -123,25 +123,25 @@ func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreig return foreign } -//SetOnDelete sets the on delete of foreign +// SetOnDelete sets the on delete of foreign func (foreign *Foreign) SetOnDelete(del string) *Foreign { foreign.OnDelete = "ON DELETE" + del return foreign } -//SetOnUpdate sets the on update of foreign +// SetOnUpdate sets the on update of foreign func (foreign *Foreign) SetOnUpdate(update string) *Foreign { foreign.OnUpdate = "ON UPDATE" + update return foreign } -//Remove marks the columns to be removed. -//it allows reverse m to create the column. +// Remove marks the columns to be removed. +// it allows reverse m to create the column. func (c *Column) Remove() { c.remove = true } -//SetAuto enables auto_increment of column (can be used once) +// SetAuto enables auto_increment of column (can be used once) func (c *Column) SetAuto(inc bool) *Column { if inc { c.Inc = "auto_increment" @@ -149,7 +149,7 @@ func (c *Column) SetAuto(inc bool) *Column { return c } -//SetNullable sets the column to be null +// SetNullable sets the column to be null func (c *Column) SetNullable(null bool) *Column { if null { c.Null = "" @@ -160,13 +160,13 @@ func (c *Column) SetNullable(null bool) *Column { return c } -//SetDefault sets the default value, prepend with "DEFAULT " +// SetDefault sets the default value, prepend with "DEFAULT " func (c *Column) SetDefault(def string) *Column { c.Default = "DEFAULT " + def return c } -//SetUnsigned sets the column to be unsigned int +// SetUnsigned sets the column to be unsigned int func (c *Column) SetUnsigned(unsign bool) *Column { if unsign { c.Unsign = "UNSIGNED" @@ -174,13 +174,13 @@ func (c *Column) SetUnsigned(unsign bool) *Column { return c } -//SetDataType sets the dataType of the column +// SetDataType sets the dataType of the column func (c *Column) SetDataType(dataType string) *Column { c.DataType = dataType return c } -//SetOldNullable allows reverting to previous nullable on reverse ms +// SetOldNullable allows reverting to previous nullable on reverse ms func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { if null { c.OldNull = "" @@ -191,13 +191,13 @@ func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { return c } -//SetOldDefault allows reverting to previous default on reverse ms +// SetOldDefault allows reverting to previous default on reverse ms func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { c.OldDefault = def return c } -//SetOldUnsigned allows reverting to previous unsgined on reverse ms +// SetOldUnsigned allows reverting to previous unsgined on reverse ms func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { if unsign { c.OldUnsign = "UNSIGNED" @@ -205,19 +205,19 @@ func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { return c } -//SetOldDataType allows reverting to previous datatype on reverse ms +// SetOldDataType allows reverting to previous datatype on reverse ms func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { c.OldDataType = dataType return c } -//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) +// SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) func (c *Column) SetPrimary(m *Migration) *Column { m.Primary = append(m.Primary, c) return c } -//AddColumnsToUnique adds the columns to Unique Struct +// AddColumnsToUnique adds the columns to Unique Struct func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { unique.Columns = append(unique.Columns, columns...) @@ -225,7 +225,7 @@ func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { return unique } -//AddColumns adds columns to m struct +// AddColumns adds columns to m struct func (m *Migration) AddColumns(columns ...*Column) *Migration { m.Columns = append(m.Columns, columns...) @@ -233,38 +233,38 @@ func (m *Migration) AddColumns(columns ...*Column) *Migration { return m } -//AddPrimary adds the column to primary in m struct +// AddPrimary adds the column to primary in m struct func (m *Migration) AddPrimary(primary *Column) *Migration { m.Primary = append(m.Primary, primary) return m } -//AddUnique adds the column to unique in m struct +// AddUnique adds the column to unique in m struct func (m *Migration) AddUnique(unique *Unique) *Migration { m.Uniques = append(m.Uniques, unique) return m } -//AddForeign adds the column to foreign in m struct +// AddForeign adds the column to foreign in m struct func (m *Migration) AddForeign(foreign *Foreign) *Migration { m.Foreigns = append(m.Foreigns, foreign) return m } -//AddIndex adds the column to index in m struct +// AddIndex adds the column to index in m struct func (m *Migration) AddIndex(index *Index) *Migration { m.Indexes = append(m.Indexes, index) return m } -//RenameColumn allows renaming of columns +// RenameColumn allows renaming of columns func (m *Migration) RenameColumn(from, to string) *RenameColumn { rename := &RenameColumn{OldName: from, NewName: to} m.Renames = append(m.Renames, rename) return rename } -//GetSQL returns the generated sql depending on ModifyType +// GetSQL returns the generated sql depending on ModifyType func (m *Migration) GetSQL() (sql string) { sql = "" switch m.ModifyType { diff --git a/pkg/client/orm/models.go b/pkg/client/orm/models.go index b38ea9e5..ea315c2b 100644 --- a/pkg/client/orm/models.go +++ b/pkg/client/orm/models.go @@ -442,7 +442,7 @@ func (mc *_modelCache) getDbDropSQL(al *alias) (queries []string, err error) { for _, mi := range mc.allOrdered() { queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) } - return queries,nil + return queries, nil } //getDbCreateSQL get database scheme creation sql queries diff --git a/pkg/client/orm/models_test.go b/pkg/client/orm/models_test.go index 81ba30df..f0044f6d 100644 --- a/pkg/client/orm/models_test.go +++ b/pkg/client/orm/models_test.go @@ -22,8 +22,6 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/client/orm/hints" - _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -529,7 +527,7 @@ func init() { os.Exit(2) } - err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, hints.MaxIdleConnections(20)) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, MaxIdleConnections(20)) if err != nil { panic(fmt.Sprintf("can not register database: %v", err)) diff --git a/pkg/client/orm/models_utils.go b/pkg/client/orm/models_utils.go index 6fca59a9..950ca243 100644 --- a/pkg/client/orm/models_utils.go +++ b/pkg/client/orm/models_utils.go @@ -107,6 +107,18 @@ func getTableUnique(val reflect.Value) [][]string { return nil } +// get whether the table needs to be created for the database alias +func isApplicableTableForDB(val reflect.Value, db string) bool { + fun := val.MethodByName("IsApplicableTableForDB") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{reflect.ValueOf(db)}) + if len(vals) > 0 && vals[0].Kind() == reflect.Bool { + return vals[0].Bool() + } + } + return true +} + // get snaked column name func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { column := col diff --git a/pkg/client/orm/models_utils_test.go b/pkg/client/orm/models_utils_test.go new file mode 100644 index 00000000..0a6995b3 --- /dev/null +++ b/pkg/client/orm/models_utils_test.go @@ -0,0 +1,35 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type NotApplicableModel struct { + Id int +} + +func (n *NotApplicableModel) IsApplicableTableForDB(db string) bool { + return db == "default" +} + +func Test_IsApplicableTableForDB(t *testing.T) { + assert.False(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "defa")) + assert.True(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "default")) +} diff --git a/pkg/client/orm/orm.go b/pkg/client/orm/orm.go index a18dae3c..557c788c 100644 --- a/pkg/client/orm/orm.go +++ b/pkg/client/orm/orm.go @@ -311,9 +311,7 @@ func (o *ormBase) LoadRelated(md interface{}, name string, args ...utils.KV) (in return o.LoadRelatedWithCtx(context.Background(), md, name, args...) } func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { - _, fi, ind, qseter := o.queryRelated(md, name) - - qs := qseter.(*querySet) + _, fi, ind, qs := o.queryRelated(md, name) var relDepth int var limit, offset int64 @@ -377,7 +375,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s } // get QuerySeter for related models to md model -func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { +func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, *querySet) { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -616,7 +614,7 @@ func NewOrmUsingDB(aliasName string) Ormer { } // NewOrmWithDB create a new ormer object with specify *sql.DB for query -func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...utils.KV) (Ormer, error) { +func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...DBOption) (Ormer, error) { al, err := newAliasWithDb(aliasName, driverName, db, params...) if err != nil { return nil, err diff --git a/pkg/client/orm/orm_raw.go b/pkg/client/orm/orm_raw.go index c2539147..e11e97fa 100644 --- a/pkg/client/orm/orm_raw.go +++ b/pkg/client/orm/orm_raw.go @@ -330,6 +330,8 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { return err } + structTagMap := make(map[reflect.StructTag]map[string]string) + defer rows.Close() if rows.Next() { @@ -396,7 +398,12 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { recursiveSetField(f) } - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + // thanks @Gazeboxu. + tags := structTagMap[fe.Tag] + if tags == nil { + _, tags = parseStructTag(fe.Tag.Get(defaultStructTagName)) + structTagMap[fe.Tag] = tags + } var col string if col = tags["column"]; col == "" { col = nameStrategyMap[nameStrategy](fe.Name) diff --git a/pkg/client/orm/types.go b/pkg/client/orm/types.go index cee570af..b9c444eb 100644 --- a/pkg/client/orm/types.go +++ b/pkg/client/orm/types.go @@ -75,6 +75,11 @@ type TableUniqueI interface { TableUnique() [][]string } +// IsApplicableTableForDB if return false, we won't create table to this db +type IsApplicableTableForDB interface { + IsApplicableTableForDB(db string) bool +} + // Driver define database driver type Driver interface { Name() string diff --git a/pkg/infrastructure/logs/alils/alils.go b/pkg/infrastructure/logs/alils/alils.go index 02e6f995..0689aae0 100644 --- a/pkg/infrastructure/logs/alils/alils.go +++ b/pkg/infrastructure/logs/alils/alils.go @@ -2,11 +2,14 @@ package alils import ( "encoding/json" + "fmt" "strings" "sync" - "github.com/astaxie/beego/pkg/infrastructure/logs" "github.com/gogo/protobuf/proto" + "github.com/pkg/errors" + + "github.com/astaxie/beego/pkg/infrastructure/logs" ) const ( @@ -27,6 +30,7 @@ type Config struct { Source string `json:"source"` Level int `json:"level"` FlushWhen int `json:"flush_when"` + Formatter string `json:"formatter"` } // aliLSWriter implements LoggerInterface. @@ -38,19 +42,23 @@ type aliLSWriter struct { groupMap map[string]*LogGroup lock *sync.Mutex Config + formatter logs.LogFormatter } // NewAliLS creates a new Logger func NewAliLS() logs.Logger { alils := new(aliLSWriter) alils.Level = logs.LevelTrace + alils.formatter = alils return alils } // Init parses config and initializes struct -func (c *aliLSWriter) Init(jsonConfig string) (err error) { - - json.Unmarshal([]byte(jsonConfig), c) +func (c *aliLSWriter) Init(config string) error { + err := json.Unmarshal([]byte(config), c) + if err != nil { + return err + } if c.FlushWhen > CacheSize { c.FlushWhen = CacheSize @@ -63,11 +71,13 @@ func (c *aliLSWriter) Init(jsonConfig string) (err error) { AccessKeySecret: c.KeySecret, } - c.store, err = prj.GetLogStore(c.LogStore) + store, err := prj.GetLogStore(c.LogStore) if err != nil { return err } + c.store = store + // Create default Log Group c.group = append(c.group, &LogGroup{ Topic: proto.String(""), @@ -97,9 +107,25 @@ func (c *aliLSWriter) Init(jsonConfig string) (err error) { c.lock = &sync.Mutex{} + if len(c.Formatter) > 0 { + fmtr, ok := logs.GetFormatter(c.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", c.Formatter)) + } + c.formatter = fmtr + } + return nil } +func (c *aliLSWriter) Format(lm *logs.LogMsg) string { + return lm.OldStyleFormat() +} + +func (c *aliLSWriter) SetFormatter(f logs.LogFormatter) { + c.formatter = f +} + // WriteMsg writes a message in connection. // If connection is down, try to re-connect. func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { @@ -117,20 +143,19 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { if len(strs) == 2 { pos := strings.LastIndex(strs[0], " ") topic = strs[0][pos+1 : len(strs[0])] - content = strs[0][0:pos] + strs[1] lg = c.groupMap[topic] } // send to empty Topic if lg == nil { - content = lm.Msg lg = c.group[0] } } else { - content = lm.Msg lg = c.group[0] } + content = c.formatter.Format(lm) + c1 := &LogContent{ Key: proto.String("msg"), Value: proto.String(content), @@ -150,7 +175,6 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { if len(lg.Logs) >= c.FlushWhen { c.flush(lg) } - return nil } diff --git a/pkg/infrastructure/logs/conn.go b/pkg/infrastructure/logs/conn.go index e0560fd9..1fd71be7 100644 --- a/pkg/infrastructure/logs/conn.go +++ b/pkg/infrastructure/logs/conn.go @@ -16,8 +16,11 @@ package logs import ( "encoding/json" + "fmt" "io" "net" + + "github.com/pkg/errors" ) // connWriter implements LoggerInterface. @@ -25,6 +28,8 @@ import ( type connWriter struct { lg *logWriter innerWriter io.WriteCloser + formatter LogFormatter + Formatter string `json:"formatter"` ReconnectOnMsg bool `json:"reconnectOnMsg"` Reconnect bool `json:"reconnect"` Net string `json:"net"` @@ -36,13 +41,30 @@ type connWriter struct { func NewConn() Logger { conn := new(connWriter) conn.Level = LevelTrace + conn.formatter = conn return conn } +func (c *connWriter) Format(lm *LogMsg) string { + return lm.OldStyleFormat() +} + // Init initializes a connection writer with json config. // json config only needs they "level" key -func (c *connWriter) Init(jsonConfig string) error { - return json.Unmarshal([]byte(jsonConfig), c) +func (c *connWriter) Init(config string) error { + res := json.Unmarshal([]byte(config), c) + if res == nil && len(c.Formatter) > 0 { + fmtr, ok := GetFormatter(c.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", c.Formatter)) + } + c.formatter = fmtr + } + return res +} + +func (c *connWriter) SetFormatter(f LogFormatter) { + c.formatter = f } // WriteMsg writes message in connection. @@ -62,7 +84,9 @@ func (c *connWriter) WriteMsg(lm *LogMsg) error { defer c.innerWriter.Close() } - _, err := c.lg.writeln(lm) + msg := c.formatter.Format(lm) + + _, err := c.lg.writeln(msg) if err != nil { return err } diff --git a/pkg/infrastructure/logs/console.go b/pkg/infrastructure/logs/console.go index 024152aa..f99ef11b 100644 --- a/pkg/infrastructure/logs/console.go +++ b/pkg/infrastructure/logs/console.go @@ -16,9 +16,11 @@ package logs import ( "encoding/json" + "fmt" "os" "strings" + "github.com/pkg/errors" "github.com/shiena/ansicolor" ) @@ -47,9 +49,25 @@ var colors = []brush{ // consoleWriter implements LoggerInterface and writes messages to terminal. type consoleWriter struct { - lg *logWriter - Level int `json:"level"` - Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color + lg *logWriter + formatter LogFormatter + Formatter string `json:"formatter"` + Level int `json:"level"` + Colorful bool `json:"color"` // this filed is useful only when system's terminal supports color +} + +func (c *consoleWriter) Format(lm *LogMsg) string { + msg := lm.OldStyleFormat() + if c.Colorful { + msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) + } + h, _, _ := formatTimeHeader(lm.When) + bytes := append(append(h, msg...), '\n') + return string(bytes) +} + +func (c *consoleWriter) SetFormatter(f LogFormatter) { + c.formatter = f } // NewConsole creates ConsoleWriter returning as LoggerInterface. @@ -59,16 +77,27 @@ func NewConsole() Logger { Level: LevelDebug, Colorful: true, } + cw.formatter = cw return cw } // Init initianlizes the console logger. // jsonConfig must be in the format '{"level":LevelTrace}' -func (c *consoleWriter) Init(jsonConfig string) error { - if len(jsonConfig) == 0 { +func (c *consoleWriter) Init(config string) error { + + if len(config) == 0 { return nil } - return json.Unmarshal([]byte(jsonConfig), c) + + res := json.Unmarshal([]byte(config), c) + if res == nil && len(c.Formatter) > 0 { + fmtr, ok := GetFormatter(c.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", c.Formatter)) + } + c.formatter = fmtr + } + return res } // WriteMsg writes message in console. @@ -76,10 +105,8 @@ func (c *consoleWriter) WriteMsg(lm *LogMsg) error { if lm.Level > c.Level { return nil } - if c.Colorful { - lm.Msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) - } - c.lg.writeln(lm) + msg := c.formatter.Format(lm) + c.lg.writeln(msg) return nil } diff --git a/pkg/infrastructure/logs/es/es.go b/pkg/infrastructure/logs/es/es.go index c2631584..438a6da6 100644 --- a/pkg/infrastructure/logs/es/es.go +++ b/pkg/infrastructure/logs/es/es.go @@ -31,13 +31,34 @@ func NewES() logs.Logger { // import _ "github.com/astaxie/beego/logs/es" type esLogger struct { *elasticsearch.Client - DSN string `json:"dsn"` - Level int `json:"level"` + DSN string `json:"dsn"` + Level int `json:"level"` + formatter logs.LogFormatter + Formatter string `json:"formatter"` +} + +func (el *esLogger) Format(lm *logs.LogMsg) string { + + msg := lm.OldStyleFormat() + idx := LogDocument{ + Timestamp: lm.When.Format(time.RFC3339), + Msg: msg, + } + body, err := json.Marshal(idx) + if err != nil { + return msg + } + return string(body) +} + +func (el *esLogger) SetFormatter(f logs.LogFormatter) { + el.formatter = f } // {"dsn":"http://localhost:9200/","level":1} -func (el *esLogger) Init(jsonconfig string) error { - err := json.Unmarshal([]byte(jsonconfig), el) +func (el *esLogger) Init(config string) error { + + err := json.Unmarshal([]byte(config), el) if err != nil { return err } @@ -56,6 +77,13 @@ func (el *esLogger) Init(jsonconfig string) error { } el.Client = conn } + if len(el.Formatter) > 0 { + fmtr, ok := logs.GetFormatter(el.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", el.Formatter)) + } + el.formatter = fmtr + } return nil } @@ -65,21 +93,14 @@ func (el *esLogger) WriteMsg(lm *logs.LogMsg) error { return nil } - idx := LogDocument{ - Timestamp: lm.When.Format(time.RFC3339), - Msg: lm.Msg, - } + msg := el.formatter.Format(lm) - body, err := json.Marshal(idx) - if err != nil { - return err - } req := esapi.IndexRequest{ Index: fmt.Sprintf("%04d.%02d.%02d", lm.When.Year(), lm.When.Month(), lm.When.Day()), DocumentType: "logs", - Body: strings.NewReader(string(body)), + Body: strings.NewReader(msg), } - _, err = req.Do(context.Background(), el.Client) + _, err := req.Do(context.Background(), el.Client) return err } diff --git a/pkg/infrastructure/logs/file.go b/pkg/infrastructure/logs/file.go index 23ea4b09..b01be357 100644 --- a/pkg/infrastructure/logs/file.go +++ b/pkg/infrastructure/logs/file.go @@ -69,6 +69,9 @@ type fileLogWriter struct { RotatePerm string `json:"rotateperm"` fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix + + formatter LogFormatter + Formatter string `json:"formatter"` } // newFileWriter creates a FileLogWriter returning as LoggerInterface. @@ -86,9 +89,21 @@ func newFileWriter() Logger { MaxFiles: 999, MaxSize: 1 << 28, } + w.formatter = w return w } +func (w *fileLogWriter) Format(lm *LogMsg) string { + msg := lm.OldStyleFormat() + hd, _, _ := formatTimeHeader(lm.When) + msg = fmt.Sprintf("%s %s\n", string(hd), msg) + return msg +} + +func (w *fileLogWriter) SetFormatter(f LogFormatter) { + w.formatter = f +} + // Init file logger with json config. // jsonConfig like: // { @@ -100,8 +115,9 @@ func newFileWriter() Logger { // "rotate":true, // "perm":"0600" // } -func (w *fileLogWriter) Init(jsonConfig string) error { - err := json.Unmarshal([]byte(jsonConfig), w) +func (w *fileLogWriter) Init(config string) error { + + err := json.Unmarshal([]byte(config), w) if err != nil { return err } @@ -113,6 +129,14 @@ func (w *fileLogWriter) Init(jsonConfig string) error { if w.suffix == "" { w.suffix = ".log" } + + if len(w.Formatter) > 0 { + fmtr, ok := GetFormatter(w.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", w.Formatter)) + } + w.formatter = fmtr + } err = w.startLogger() return err } @@ -130,13 +154,13 @@ func (w *fileLogWriter) startLogger() error { return w.initFd() } -func (w *fileLogWriter) needRotateDaily(size int, day int) bool { +func (w *fileLogWriter) needRotateDaily(day int) bool { return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || (w.Daily && day != w.dailyOpenDate) } -func (w *fileLogWriter) needRotateHourly(size int, hour int) bool { +func (w *fileLogWriter) needRotateHourly(hour int) bool { return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || (w.Hourly && hour != w.hourlyOpenDate) @@ -148,23 +172,25 @@ func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { if lm.Level > w.Level { return nil } - hd, d, h := formatTimeHeader(lm.When) - lm.Msg = string(hd) + lm.Msg + "\n" + + _, d, h := formatTimeHeader(lm.When) + + msg := w.formatter.Format(lm) if w.Rotate { w.RLock() - if w.needRotateHourly(len(lm.Msg), h) { + if w.needRotateHourly(h) { w.RUnlock() w.Lock() - if w.needRotateHourly(len(lm.Msg), h) { + if w.needRotateHourly(h) { if err := w.doRotate(lm.When); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } } w.Unlock() - } else if w.needRotateDaily(len(lm.Msg), d) { + } else if w.needRotateDaily(d) { w.RUnlock() w.Lock() - if w.needRotateDaily(len(lm.Msg), d) { + if w.needRotateDaily(d) { if err := w.doRotate(lm.When); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } @@ -176,10 +202,10 @@ func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { } w.Lock() - _, err := w.fileWriter.Write([]byte(lm.Msg)) + _, err := w.fileWriter.Write([]byte(msg)) if err == nil { w.maxLinesCurLines++ - w.maxSizeCurSize += len(lm.Msg) + w.maxSizeCurSize += len(msg) } w.Unlock() return err @@ -236,7 +262,7 @@ func (w *fileLogWriter) dailyRotate(openTime time.Time) { tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) <-tm.C w.Lock() - if w.needRotateDaily(0, time.Now().Day()) { + if w.needRotateDaily(time.Now().Day()) { if err := w.doRotate(time.Now()); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } @@ -251,7 +277,7 @@ func (w *fileLogWriter) hourlyRotate(openTime time.Time) { tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100)) <-tm.C w.Lock() - if w.needRotateHourly(0, time.Now().Hour()) { + if w.needRotateHourly(time.Now().Hour()) { if err := w.doRotate(time.Now()); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } @@ -302,7 +328,7 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error { _, err = os.Lstat(w.Filename) if err != nil { - //even if the file is not exist or other ,we should RESTART the logger + // even if the file is not exist or other ,we should RESTART the logger goto RESTART_LOGGER } diff --git a/pkg/infrastructure/logs/file_test.go b/pkg/infrastructure/logs/file_test.go index 7f2a3590..494d0a9e 100644 --- a/pkg/infrastructure/logs/file_test.go +++ b/pkg/infrastructure/logs/file_test.go @@ -268,6 +268,7 @@ func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { Perm: "0660", RotatePerm: "0440", } + fw.formatter = fw if daily { fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) @@ -308,6 +309,8 @@ func testFileDailyRotate(t *testing.T, fn1, fn2 string) { Perm: "0660", RotatePerm: "0440", } + fw.formatter = fw + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) fw.dailyOpenDate = fw.dailyOpenTime.Day() @@ -340,6 +343,8 @@ func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { Perm: "0660", RotatePerm: "0440", } + + fw.formatter = fw fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) fw.hourlyOpenDate = fw.hourlyOpenTime.Hour() diff --git a/pkg/infrastructure/logs/formatter.go b/pkg/infrastructure/logs/formatter.go new file mode 100644 index 00000000..b2599f2d --- /dev/null +++ b/pkg/infrastructure/logs/formatter.go @@ -0,0 +1,34 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +var formatterMap = make(map[string]LogFormatter, 4) + +type LogFormatter interface { + Format(lm *LogMsg) string +} + +// RegisterFormatter register an formatter. Usually you should use this to extend your custom formatter +// for example: +// RegisterFormatter("my-fmt", &MyFormatter{}) +// logs.SetFormatter(Console, `{"formatter": "my-fmt"}`) +func RegisterFormatter(name string, fmtr LogFormatter) { + formatterMap[name] = fmtr +} + +func GetFormatter(name string) (LogFormatter, bool) { + res, ok := formatterMap[name] + return res, ok +} diff --git a/pkg/infrastructure/logs/jianliao.go b/pkg/infrastructure/logs/jianliao.go index 0e7cfab4..9757a7d5 100644 --- a/pkg/infrastructure/logs/jianliao.go +++ b/pkg/infrastructure/logs/jianliao.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "net/url" + + "github.com/pkg/errors" ) // JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook @@ -15,16 +17,38 @@ type JLWriter struct { RedirectURL string `json:"redirecturl,omitempty"` ImageURL string `json:"imageurl,omitempty"` Level int `json:"level"` + + formatter LogFormatter + Formatter string `json:"formatter"` } // newJLWriter creates jiaoliao writer. func newJLWriter() Logger { - return &JLWriter{Level: LevelTrace} + res := &JLWriter{Level: LevelTrace} + res.formatter = res + return res } // Init JLWriter with json config string -func (s *JLWriter) Init(jsonconfig string) error { - return json.Unmarshal([]byte(jsonconfig), s) +func (s *JLWriter) Init(config string) error { + + res := json.Unmarshal([]byte(config), s) + if res == nil && len(s.Formatter) > 0 { + fmtr, ok := GetFormatter(s.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", s.Formatter)) + } + s.formatter = fmtr + } + return res +} + +func (s *JLWriter) Format(lm *LogMsg) string { + return lm.OldStyleFormat() +} + +func (s *JLWriter) SetFormatter(f LogFormatter) { + s.formatter = f } // WriteMsg writes message in smtp writer. @@ -34,7 +58,7 @@ func (s *JLWriter) WriteMsg(lm *LogMsg) error { return nil } - text := fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), lm.Msg) + text := s.formatter.Format(lm) form := url.Values{} form.Add("authorName", s.AuthorName) diff --git a/pkg/infrastructure/logs/log.go b/pkg/infrastructure/logs/log.go index 37421625..480cecab 100644 --- a/pkg/infrastructure/logs/log.go +++ b/pkg/infrastructure/logs/log.go @@ -42,6 +42,8 @@ import ( "strings" "sync" "time" + + "github.com/pkg/errors" ) // RFC5424 log message levels. @@ -88,6 +90,7 @@ type Logger interface { WriteMsg(lm *LogMsg) error Destroy() Flush() + SetFormatter(f LogFormatter) } var adapters = make(map[string]newLoggerFunc) @@ -122,6 +125,7 @@ type BeeLogger struct { signalChan chan string wg sync.WaitGroup outputs []*nameLogger + globalFormatter string } const defaultAsyncMsgLen = 1e3 @@ -132,11 +136,15 @@ type nameLogger struct { } type LogMsg struct { - Level int - Msg string - When time.Time - FilePath string - LineNumber int + Level int + Msg string + When time.Time + FilePath string + LineNumber int + Args []interface{} + Prefix string + enableFullFilePath bool + enableFuncCallDepth bool } var logMsgPool *sync.Pool @@ -179,6 +187,27 @@ func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { return bl } +// OldStyleFormat you should never invoke this +func (lm *LogMsg) OldStyleFormat() string { + msg := lm.Msg + + if len(lm.Args) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, lm.Args...) + } + + msg = lm.Prefix + " " + msg + + if lm.enableFuncCallDepth { + if !lm.enableFullFilePath { + _, lm.FilePath = path.Split(lm.FilePath) + } + msg = fmt.Sprintf("[%s:%d] %s", lm.FilePath, lm.LineNumber, msg) + } + + msg = levelPrefix[lm.Level] + " " + msg + return msg +} + // SetLogger provides a given logger adapter into BeeLogger with config string. // config must in in JSON format like {"interval":360}} func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { @@ -195,7 +224,18 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { } lg := logAdapter() + + // Global formatter overrides the default set formatter + if len(bl.globalFormatter) > 0 { + fmtr, ok := GetFormatter(bl.globalFormatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", bl.globalFormatter)) + } + lg.SetFormatter(fmtr) + } + err := lg.Init(config) + if err != nil { fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) return err @@ -265,46 +305,34 @@ func (bl *BeeLogger) Write(p []byte) (n int, err error) { return 0, err } -func (bl *BeeLogger) writeMsg(lm *LogMsg, v ...interface{}) error { +func (bl *BeeLogger) writeMsg(lm *LogMsg) error { if !bl.init { bl.lock.Lock() bl.setLogger(AdapterConsole) bl.lock.Unlock() } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) - } - - lm.Msg = bl.prefix + " " + lm.Msg - var ( file string line int ok bool ) - if bl.enableFuncCallDepth { - _, file, line, ok = runtime.Caller(bl.loggerFuncCallDepth) - if !ok { - file = "???" - line = 0 - } - - if !bl.enableFullFilePath { - _, file = path.Split(file) - } - lm.FilePath = file - lm.LineNumber = line - lm.Msg = fmt.Sprintf("[%s:%d] %s", lm.FilePath, lm.LineNumber, lm.Msg) + _, file, line, ok = runtime.Caller(bl.loggerFuncCallDepth) + if !ok { + file = "???" + line = 0 } + lm.FilePath = file + lm.LineNumber = line - //set level info in front of filename info + lm.enableFullFilePath = bl.enableFullFilePath + lm.enableFuncCallDepth = bl.enableFuncCallDepth + + // set level info in front of filename info if lm.Level == levelLoggerImpl { // set to emergency to ensure all log will be print out correctly lm.Level = LevelEmergency - } else { - lm.Msg = levelPrefix[lm.Level] + " " + lm.Msg } if bl.asynchronous { @@ -312,6 +340,10 @@ func (bl *BeeLogger) writeMsg(lm *LogMsg, v ...interface{}) error { logM.Level = lm.Level logM.Msg = lm.Msg logM.When = lm.When + logM.Args = lm.Args + logM.FilePath = lm.FilePath + logM.LineNumber = lm.LineNumber + logM.Prefix = lm.Prefix if bl.outputs != nil { bl.msgChan <- lm } else { @@ -382,6 +414,17 @@ func (bl *BeeLogger) startLogger() { } } +func (bl *BeeLogger) setGlobalFormatter(fmtter string) error { + bl.globalFormatter = fmtter + return nil +} + +// SetGlobalFormatter sets the global formatter for all log adapters +// don't forget to register the formatter by invoking RegisterFormatter +func SetGlobalFormatter(fmtter string) error { + return beeLogger.setGlobalFormatter(fmtter) +} + // Emergency Log EMERGENCY level message. func (bl *BeeLogger) Emergency(format string, v ...interface{}) { if LevelEmergency > bl.level { @@ -410,11 +453,8 @@ func (bl *BeeLogger) Alert(format string, v ...interface{}) { Level: LevelAlert, Msg: format, When: time.Now(), + Args: v, } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) - } - bl.writeMsg(lm) } @@ -427,9 +467,7 @@ func (bl *BeeLogger) Critical(format string, v ...interface{}) { Level: LevelCritical, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -444,9 +482,7 @@ func (bl *BeeLogger) Error(format string, v ...interface{}) { Level: LevelError, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -461,9 +497,7 @@ func (bl *BeeLogger) Warning(format string, v ...interface{}) { Level: LevelWarn, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -478,9 +512,7 @@ func (bl *BeeLogger) Notice(format string, v ...interface{}) { Level: LevelNotice, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -495,9 +527,7 @@ func (bl *BeeLogger) Informational(format string, v ...interface{}) { Level: LevelInfo, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -512,9 +542,7 @@ func (bl *BeeLogger) Debug(format string, v ...interface{}) { Level: LevelDebug, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -530,9 +558,7 @@ func (bl *BeeLogger) Warn(format string, v ...interface{}) { Level: LevelWarn, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -548,9 +574,7 @@ func (bl *BeeLogger) Info(format string, v ...interface{}) { Level: LevelInfo, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -566,9 +590,7 @@ func (bl *BeeLogger) Trace(format string, v ...interface{}) { Level: LevelDebug, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -777,9 +799,9 @@ func formatLog(f interface{}, v ...interface{}) string { return msg } if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") { - //format string + // format string } else { - //do not contain format char + // do not contain format char msg += strings.Repeat(" %v", len(v)) } default: diff --git a/pkg/infrastructure/logs/logger.go b/pkg/infrastructure/logs/logger.go index 721c8dc1..d8b334d4 100644 --- a/pkg/infrastructure/logs/logger.go +++ b/pkg/infrastructure/logs/logger.go @@ -30,10 +30,10 @@ func newLogWriter(wr io.Writer) *logWriter { return &logWriter{writer: wr} } -func (lg *logWriter) writeln(lm *LogMsg) (int, error) { +func (lg *logWriter) writeln(msg string) (int, error) { lg.Lock() - h, _, _ := formatTimeHeader(lm.When) - n, err := lg.writer.Write(append(append(h, lm.Msg...), '\n')) + msg += "\n" + n, err := lg.writer.Write([]byte(msg)) lg.Unlock() return n, err } diff --git a/pkg/infrastructure/logs/multifile.go b/pkg/infrastructure/logs/multifile.go index 1cd9e9f8..79178211 100644 --- a/pkg/infrastructure/logs/multifile.go +++ b/pkg/infrastructure/logs/multifile.go @@ -45,6 +45,7 @@ var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning // } func (f *multiFileLogWriter) Init(config string) error { + writer := newFileWriter().(*fileLogWriter) err := writer.Init(config) if err != nil { @@ -53,11 +54,17 @@ func (f *multiFileLogWriter) Init(config string) error { f.fullLogWriter = writer f.writers[LevelDebug+1] = writer - //unmarshal "separate" field to f.Separate - json.Unmarshal([]byte(config), f) + // unmarshal "separate" field to f.Separate + err = json.Unmarshal([]byte(config), f) + if err != nil { + return err + } jsonMap := map[string]interface{}{} - json.Unmarshal([]byte(config), &jsonMap) + err = json.Unmarshal([]byte(config), &jsonMap) + if err != nil { + return err + } for i := LevelEmergency; i < LevelDebug+1; i++ { for _, v := range f.Separate { @@ -74,10 +81,17 @@ func (f *multiFileLogWriter) Init(config string) error { } } } - return nil } +func (f *multiFileLogWriter) Format(lm *LogMsg) string { + return lm.OldStyleFormat() +} + +func (f *multiFileLogWriter) SetFormatter(fmt LogFormatter) { + f.fullLogWriter.SetFormatter(f) +} + func (f *multiFileLogWriter) Destroy() { for i := 0; i < len(f.writers); i++ { if f.writers[i] != nil { @@ -110,7 +124,8 @@ func (f *multiFileLogWriter) Flush() { // newFilesWriter create a FileLogWriter returning as LoggerInterface. func newFilesWriter() Logger { - return &multiFileLogWriter{} + res := &multiFileLogWriter{} + return res } func init() { diff --git a/pkg/infrastructure/logs/slack.go b/pkg/infrastructure/logs/slack.go index dad4f4ea..b6e2f170 100644 --- a/pkg/infrastructure/logs/slack.go +++ b/pkg/infrastructure/logs/slack.go @@ -5,22 +5,47 @@ import ( "fmt" "net/http" "net/url" + + "github.com/pkg/errors" ) // SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook type SLACKWriter struct { WebhookURL string `json:"webhookurl"` Level int `json:"level"` + formatter LogFormatter + Formatter string `json:"formatter"` } // newSLACKWriter creates jiaoliao writer. func newSLACKWriter() Logger { - return &SLACKWriter{Level: LevelTrace} + res := &SLACKWriter{Level: LevelTrace} + res.formatter = res + return res +} + +func (s *SLACKWriter) Format(lm *LogMsg) string { + text := fmt.Sprintf("{\"text\": \"%s %s\"}", lm.When.Format("2006-01-02 15:04:05"), lm.OldStyleFormat()) + return text +} + +func (s *SLACKWriter) SetFormatter(f LogFormatter) { + s.formatter = f } // Init SLACKWriter with json config string -func (s *SLACKWriter) Init(jsonconfig string) error { - return json.Unmarshal([]byte(jsonconfig), s) +func (s *SLACKWriter) Init(config string) error { + res := json.Unmarshal([]byte(config), s) + + if res == nil && len(s.Formatter) > 0 { + fmtr, ok := GetFormatter(s.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", s.Formatter)) + } + s.formatter = fmtr + } + + return res } // WriteMsg write message in smtp writer. @@ -29,11 +54,9 @@ func (s *SLACKWriter) WriteMsg(lm *LogMsg) error { if lm.Level > s.Level { return nil } - - text := fmt.Sprintf("{\"text\": \"%s %s\"}", lm.When.Format("2006-01-02 15:04:05"), lm.Msg) - + msg := s.Format(lm) form := url.Values{} - form.Add("payload", text) + form.Add("payload", msg) resp, err := http.PostForm(s.WebhookURL, form) if err != nil { diff --git a/pkg/infrastructure/logs/smtp.go b/pkg/infrastructure/logs/smtp.go index 0d2b3c29..40891a7c 100644 --- a/pkg/infrastructure/logs/smtp.go +++ b/pkg/infrastructure/logs/smtp.go @@ -21,6 +21,8 @@ import ( "net" "net/smtp" "strings" + + "github.com/pkg/errors" ) // SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. @@ -32,11 +34,15 @@ type SMTPWriter struct { FromAddress string `json:"fromAddress"` RecipientAddresses []string `json:"sendTos"` Level int `json:"level"` + formatter LogFormatter + Formatter string `json:"formatter"` } // NewSMTPWriter creates the smtp writer. func newSMTPWriter() Logger { - return &SMTPWriter{Level: LevelTrace} + res := &SMTPWriter{Level: LevelTrace} + res.formatter = res + return res } // Init smtp writer with json config. @@ -50,8 +56,16 @@ func newSMTPWriter() Logger { // "sendTos":["email1","email2"], // "level":LevelError // } -func (s *SMTPWriter) Init(jsonconfig string) error { - return json.Unmarshal([]byte(jsonconfig), s) +func (s *SMTPWriter) Init(config string) error { + res := json.Unmarshal([]byte(config), s) + if res == nil && len(s.Formatter) > 0 { + fmtr, ok := GetFormatter(s.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", s.Formatter)) + } + s.formatter = fmtr + } + return res } func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { @@ -66,6 +80,10 @@ func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { ) } +func (s *SMTPWriter) SetFormatter(f LogFormatter) { + s.formatter = f +} + func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error { client, err := smtp.Dial(hostAddressWithPort) if err != nil { @@ -114,6 +132,10 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd return client.Quit() } +func (s *SMTPWriter) Format(lm *LogMsg) string { + return lm.OldStyleFormat() +} + // WriteMsg writes message in smtp writer. // Sends an email with subject and only this message. func (s *SMTPWriter) WriteMsg(lm *LogMsg) error { @@ -126,11 +148,13 @@ func (s *SMTPWriter) WriteMsg(lm *LogMsg) error { // Set up authentication information. auth := s.getSMTPAuth(hp[0]) + msg := s.Format(lm) + // Connect to the server, authenticate, set the sender and recipient, // and send the email all in one step. contentType := "Content-Type: text/plain" + "; charset=UTF-8" mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress + - ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", lm.When.Format("2006-01-02 15:04:05")) + lm.Msg) + ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", lm.When.Format("2006-01-02 15:04:05")) + msg) return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg) } diff --git a/pkg/infrastructure/session/sess_cookie.go b/pkg/infrastructure/session/sess_cookie.go index ffb19fb7..649f6510 100644 --- a/pkg/infrastructure/session/sess_cookie.go +++ b/pkg/infrastructure/session/sess_cookie.go @@ -172,7 +172,7 @@ func (pder *CookieProvider) SessionAll(context.Context) int { } // SessionUpdate Implement method, no used. -func (pder *CookieProvider) SessionUpdate(sid string) error { +func (pder *CookieProvider) SessionUpdate(ctx context.Context, sid string) error { return nil } diff --git a/pkg/infrastructure/session/sess_mem.go b/pkg/infrastructure/session/sess_mem.go index 9a27c331..27e24c73 100644 --- a/pkg/infrastructure/session/sess_mem.go +++ b/pkg/infrastructure/session/sess_mem.go @@ -96,7 +96,7 @@ func (pder *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, sav func (pder *MemProvider) SessionRead(ctx context.Context, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[sid]; ok { - go pder.SessionUpdate(sid) + go pder.SessionUpdate(nil, sid) pder.lock.RUnlock() return element.Value.(*MemSessionStore), nil } @@ -123,7 +123,7 @@ func (pder *MemProvider) SessionExist(ctx context.Context, sid string) (bool, er func (pder *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[oldsid]; ok { - go pder.SessionUpdate(oldsid) + go pder.SessionUpdate(nil, oldsid) pder.lock.RUnlock() pder.lock.Lock() element.Value.(*MemSessionStore).sid = sid @@ -181,7 +181,7 @@ func (pder *MemProvider) SessionAll(context.Context) int { } // SessionUpdate expand time of session store by id in memory session -func (pder *MemProvider) SessionUpdate(sid string) error { +func (pder *MemProvider) SessionUpdate(ctx context.Context, sid string) error { pder.lock.Lock() defer pder.lock.Unlock() if element, ok := pder.sessions[sid]; ok { diff --git a/pkg/server/web/app.go b/pkg/server/web/app.go index e61084a5..7511c7fe 100644 --- a/pkg/server/web/app.go +++ b/pkg/server/web/app.go @@ -199,7 +199,7 @@ func (app *App) Run(mws ...MiddleWare) { pool.AppendCertsFromPEM(data) app.Server.TLSConfig = &tls.Config{ ClientCAs: pool, - ClientAuth: tls.RequireAndVerifyClientCert, + ClientAuth: tls.ClientAuthType(BConfig.Listen.ClientAuth), } } if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { @@ -492,15 +492,15 @@ func Handler(rootpath string, h http.Handler, options ...interface{}) *App { // The pos means action constant including // beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) -func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { - BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) +func InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) *App { + BeeApp.Handlers.InsertFilter(pattern, pos, filter, opts...) return BeeApp } // InsertFilterChain adds a FilterFunc built by filterChain. // This filter will be executed before all filters. // the filter's behavior is like stack -func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App { - BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...) +func InsertFilterChain(pattern string, filterChain FilterChain, opts ...FilterOpt) *App { + BeeApp.Handlers.InsertFilterChain(pattern, filterChain, opts...) return BeeApp } diff --git a/pkg/server/web/config.go b/pkg/server/web/config.go index bf8db30e..6e69a2fb 100644 --- a/pkg/server/web/config.go +++ b/pkg/server/web/config.go @@ -16,6 +16,7 @@ package web import ( context2 "context" + "crypto/tls" "fmt" "os" "path/filepath" @@ -72,6 +73,7 @@ type Listen struct { AdminPort int EnableFcgi bool EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O + ClientAuth int } // WebConfig holds web related config @@ -234,6 +236,7 @@ func newBConfig() *Config { AdminPort: 8088, EnableFcgi: false, EnableStdIo: false, + ClientAuth: int(tls.RequireAndVerifyClientCert), }, WebConfig: WebConfig{ AutoRender: true, diff --git a/pkg/server/web/context/input.go b/pkg/server/web/context/input.go index a6fec774..f8657f84 100644 --- a/pkg/server/web/context/input.go +++ b/pkg/server/web/context/input.go @@ -89,7 +89,7 @@ func (input *BeegoInput) URI() string { // URL returns the request url path (without query, string and fragment). func (input *BeegoInput) URL() string { - return input.Context.Request.URL.EscapedPath() + return input.Context.Request.URL.Path } // Site returns the base site url as scheme://domain type. diff --git a/pkg/server/web/filter.go b/pkg/server/web/filter.go index 8d3acb24..9aab48d6 100644 --- a/pkg/server/web/filter.go +++ b/pkg/server/web/filter.go @@ -43,24 +43,27 @@ type FilterRouter struct { // params is for: // 1. setting the returnOnOutput value (false allows multiple filters to execute) // 2. determining whether or not params need to be reset. -func newFilterRouter(pattern string, routerCaseSensitive bool, filter FilterFunc, params ...bool) *FilterRouter { +func newFilterRouter(pattern string, filter FilterFunc, opts ...FilterOpt) *FilterRouter { mr := &FilterRouter{ - tree: NewTree(), - pattern: pattern, - filterFunc: filter, + tree: NewTree(), + pattern: pattern, + filterFunc: filter, + } + + fos := &filterOpts{ returnOnOutput: true, } - if !routerCaseSensitive { + + for _, o := range opts { + o(fos) + } + + if !fos.routerCaseSensitive { mr.pattern = strings.ToLower(pattern) } - paramsLen := len(params) - if paramsLen > 0 { - mr.returnOnOutput = params[0] - } - if paramsLen > 1 { - mr.resetParams = params[1] - } + mr.returnOnOutput = fos.returnOnOutput + mr.resetParams = fos.resetParams mr.tree.AddRouter(pattern, true) return mr } @@ -103,3 +106,29 @@ func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { } return false } + +type filterOpts struct { + returnOnOutput bool + resetParams bool + routerCaseSensitive bool +} + +type FilterOpt func(opts *filterOpts) + +func WithReturnOnOutput(ret bool) FilterOpt { + return func(opts *filterOpts) { + opts.returnOnOutput = ret + } +} + +func WithResetParams(reset bool) FilterOpt { + return func(opts *filterOpts) { + opts.resetParams = reset + } +} + +func WithCaseSensitive(sensitive bool) FilterOpt { + return func(opts *filterOpts) { + opts.routerCaseSensitive = sensitive + } +} diff --git a/pkg/server/web/filter/apiauth/apiauth.go b/pkg/server/web/filter/apiauth/apiauth.go index ba56030b..8944db63 100644 --- a/pkg/server/web/filter/apiauth/apiauth.go +++ b/pkg/server/web/filter/apiauth/apiauth.go @@ -83,11 +83,6 @@ func APIBasicAuth(appid, appkey string) web.FilterFunc { return APISecretAuth(ft, 300) } -// APIBasicAuth calls APIBasicAuth for previous callers -func APIBaiscAuth(appid, appkey string) web.FilterFunc { - return APIBasicAuth(appid, appkey) -} - // APISecretAuth uses AppIdToAppSecret verify and func APISecretAuth(f AppIDToAppSecret, timeout int) web.FilterFunc { return func(ctx *context.Context) { diff --git a/pkg/server/web/namespace.go b/pkg/server/web/namespace.go index e59f38c5..a792aa60 100644 --- a/pkg/server/web/namespace.go +++ b/pkg/server/web/namespace.go @@ -91,7 +91,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { a = FinishRouter } for _, f := range filter { - n.handlers.InsertFilter("*", a, f) + n.handlers.InsertFilter("*", a, f, WithReturnOnOutput(true)) } return n } diff --git a/pkg/server/web/router.go b/pkg/server/web/router.go index c3eddd29..3dd19a6f 100644 --- a/pkg/server/web/router.go +++ b/pkg/server/web/router.go @@ -148,7 +148,7 @@ func NewControllerRegister() *ControllerRegister { }, }, } - res.chainRoot = newFilterRouter("/*", false, res.serveHttp) + res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) return res } @@ -262,7 +262,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { if comm, ok := GlobalControllerRouter[key]; ok { for _, a := range comm { for _, f := range a.Filters { - p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams) + p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) } p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) @@ -452,8 +452,9 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) // params is for: // 1. setting the returnOnOutput value (false allows multiple filters to execute) // 2. determining whether or not params need to be reset. -func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { - mr := newFilterRouter(pattern, BConfig.RouterCaseSensitive, filter, params...) +func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) error { + opts = append(opts, WithCaseSensitive(BConfig.RouterCaseSensitive)) + mr := newFilterRouter(pattern, filter, opts...) return p.insertFilterRouter(pos, mr) } @@ -468,10 +469,11 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter // // do something // } // } -func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, params ...bool) { +func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) { root := p.chainRoot filterFunc := chain(root.filterFunc) - p.chainRoot = newFilterRouter(pattern, BConfig.RouterCaseSensitive, filterFunc, params...) + opts = append(opts, WithCaseSensitive(BConfig.RouterCaseSensitive)) + p.chainRoot = newFilterRouter(pattern, filterFunc, opts...) p.chainRoot.next = root } diff --git a/pkg/server/web/router_test.go b/pkg/server/web/router_test.go index 14ad1484..2863da3a 100644 --- a/pkg/server/web/router_test.go +++ b/pkg/server/web/router_test.go @@ -212,6 +212,23 @@ func TestAutoExtFunc(t *testing.T) { } } +func TestEscape(t *testing.T) { + + r, _ := http.NewRequest("GET", "/search/%E4%BD%A0%E5%A5%BD", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Get("/search/:keyword(.+)", func(ctx *context.Context) { + value := ctx.Input.Param(":keyword") + ctx.Output.Body([]byte(value)) + }) + handler.ServeHTTP(w, r) + str := w.Body.String() + if str != "你好" { + t.Errorf("incorrect, %s", str) + } +} + func TestRouteOk(t *testing.T) { r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil) @@ -423,7 +440,7 @@ func TestInsertFilter(t *testing.T) { testName := "TestInsertFilter" mux := NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true)) if !mux.filters[BeforeRouter][0].returnOnOutput { t.Errorf( "%s: passing no variadic params should set returnOnOutput to true", @@ -436,7 +453,7 @@ func TestInsertFilter(t *testing.T) { } mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(false)) if mux.filters[BeforeRouter][0].returnOnOutput { t.Errorf( "%s: passing false as 1st variadic param should set returnOnOutput to false", @@ -444,7 +461,7 @@ func TestInsertFilter(t *testing.T) { } mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true), WithResetParams(true)) if !mux.filters[BeforeRouter][0].resetParams { t.Errorf( "%s: passing true as 2nd variadic param should set resetParams to true", @@ -461,7 +478,7 @@ func TestParamResetFilter(t *testing.T) { mux := NewControllerRegister() - mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true) + mux.InsertFilter("*", BeforeExec, beegoResetParams, WithReturnOnOutput(true), WithResetParams(true)) mux.Get(route, beegoHandleResetParams) @@ -514,8 +531,8 @@ func TestFilterBeforeExec(t *testing.T) { url := "/beforeExec" mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoBeforeExec1) + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, BeforeExec, beegoBeforeExec1, WithReturnOnOutput(true)) mux.Get(url, beegoFilterFunc) @@ -542,7 +559,7 @@ func TestFilterAfterExec(t *testing.T) { mux := NewControllerRegister() mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoAfterExec1, false) + mux.InsertFilter(url, AfterExec, beegoAfterExec1, WithReturnOnOutput(false)) mux.Get(url, beegoFilterFunc) @@ -570,10 +587,10 @@ func TestFilterFinishRouter(t *testing.T) { url := "/finishRouter" mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoFilterNoOutput) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1) + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, AfterExec, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(true)) mux.Get(url, beegoFilterFunc) @@ -604,7 +621,7 @@ func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { url := "/finishRouterMultiFirstOnly" mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false)) mux.InsertFilter(url, FinishRouter, beegoFinishRouter2) mux.Get(url, beegoFilterFunc) @@ -631,8 +648,8 @@ func TestFilterFinishRouterMulti(t *testing.T) { url := "/finishRouterMulti" mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false)) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, WithReturnOnOutput(false)) mux.Get(url, beegoFilterFunc) diff --git a/pkg/server/web/tree.go b/pkg/server/web/tree.go index 7213a0c6..55f68076 100644 --- a/pkg/server/web/tree.go +++ b/pkg/server/web/tree.go @@ -33,13 +33,13 @@ var ( // wildcard stores params // leaves store the endpoint information type Tree struct { - //prefix set for static router + // prefix set for static router prefix string - //search fix route first + // search fix route first fixrouters []*Tree - //if set, failure to match fixrouters search then search wildcard + // if set, failure to match fixrouters search then search wildcard wildcard *Tree - //if set, failure to match wildcard search + // if set, failure to match wildcard search leaves []*leafInfo } @@ -69,13 +69,13 @@ func (t *Tree) addtree(segments []string, tree *Tree, wildcards []string, reg st filterTreeWithPrefix(tree, wildcards, reg) } } - //Rule: /login/*/access match /login/2009/11/access - //if already has *, and when loop the access, should as a regexpStr + // Rule: /login/*/access match /login/2009/11/access + // if already has *, and when loop the access, should as a regexpStr if !iswild && utils.InSlice(":splat", wildcards) { iswild = true regexpStr = seg } - //Rule: /user/:id/* + // Rule: /user/:id/* if seg == "*" && len(wildcards) > 0 && reg == "" { regexpStr = "(.+)" } @@ -222,13 +222,13 @@ func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, t.addseg(segments[1:], route, wildcards, reg) params = params[1:] } - //Rule: /login/*/access match /login/2009/11/access - //if already has *, and when loop the access, should as a regexpStr + // Rule: /login/*/access match /login/2009/11/access + // if already has *, and when loop the access, should as a regexpStr if !iswild && utils.InSlice(":splat", wildcards) { iswild = true regexpStr = seg } - //Rule: /user/:id/* + // Rule: /user/:id/* if seg == "*" && len(wildcards) > 0 && reg == "" { regexpStr = "(.+)" } @@ -393,7 +393,7 @@ type leafInfo struct { } func (leaf *leafInfo) match(treePattern string, wildcardValues []string, ctx *context.Context) (ok bool) { - //fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) + // fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) if leaf.regexps == nil { if len(wildcardValues) == 0 && len(leaf.wildcards) == 0 { // static path return true @@ -500,7 +500,7 @@ func splitSegment(key string) (bool, []string, string) { continue } if start { - //:id:int and :name:string + // :id:int and :name:string if v == ':' { if len(key) >= i+4 { if key[i+1:i+4] == "int" { diff --git a/pkg/task/task.go b/pkg/task/task.go index e2962000..bcadb956 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -83,7 +83,7 @@ type Schedule struct { } // TaskFunc task func type -type TaskFunc func() error +type TaskFunc func(ctx context.Context) error // Tasker task interface type Tasker interface { @@ -148,8 +148,8 @@ func (t *Task) GetStatus(context.Context) string { } // Run run all tasks -func (t *Task) Run(context.Context) error { - err := t.DoFunc() +func (t *Task) Run(ctx context.Context) error { + err := t.DoFunc(ctx) if err != nil { index := t.errCnt % t.ErrLimit t.Errlist[index] = &taskerr{t: t.Next, errinfo: err.Error()} diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go index 9f73ce46..488729dc 100644 --- a/pkg/task/task_test.go +++ b/pkg/task/task_test.go @@ -15,6 +15,7 @@ package task import ( + "context" "errors" "fmt" "sync" @@ -25,7 +26,10 @@ import ( ) func TestParse(t *testing.T) { - tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) + tk := NewTask("taska", "0/30 * * * * *", func(ctx context.Context) error { + fmt.Println("hello world") + return nil + }) err := tk.Run(nil) if err != nil { t.Fatal(err) @@ -39,9 +43,9 @@ func TestParse(t *testing.T) { func TestSpec(t *testing.T) { wg := &sync.WaitGroup{} wg.Add(2) - tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) - tk2 := NewTask("tk2", "0,10,20 * * * * *", func() error { fmt.Println("tk2"); wg.Done(); return nil }) - tk3 := NewTask("tk3", "0 10 * * * *", func() error { fmt.Println("tk3"); wg.Done(); return nil }) + tk1 := NewTask("tk1", "0 12 * * * *", func(ctx context.Context) error { fmt.Println("tk1"); return nil }) + tk2 := NewTask("tk2", "0,10,20 * * * * *", func(ctx context.Context) error { fmt.Println("tk2"); wg.Done(); return nil }) + tk3 := NewTask("tk3", "0 10 * * * *", func(ctx context.Context) error { fmt.Println("tk3"); wg.Done(); return nil }) AddTask("tk1", tk1) AddTask("tk2", tk2) @@ -58,7 +62,7 @@ func TestSpec(t *testing.T) { func TestTask_Run(t *testing.T) { cnt := -1 - task := func() error { + task := func(ctx context.Context) error { cnt++ fmt.Printf("Hello, world! %d \n", cnt) return errors.New(fmt.Sprintf("Hello, world! %d", cnt))