diff --git a/pkg/adapter/app.go b/pkg/adapter/app.go index 64280a7b..c1046c79 100644 --- a/pkg/adapter/app.go +++ b/pkg/adapter/app.go @@ -255,7 +255,8 @@ func Handler(rootpath string, h http.Handler, options ...interface{}) *App { // beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { + opts := oldToNewFilterOpts(params) return (*App)(web.InsertFilter(pattern, pos, func(ctx *context.Context) { filter((*context2.Context)(ctx)) - }, params...)) + }, opts...)) } diff --git a/pkg/adapter/cache/cache_test.go b/pkg/adapter/cache/cache_test.go deleted file mode 100644 index 470c0a43..00000000 --- a/pkg/adapter/cache/cache_test.go +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "os" - "sync" - "testing" - "time" -) - -func TestCacheIncr(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - //timeoutDuration := 10 * time.Second - - bm.Put("edwardhey", 0, time.Second*20) - wg := sync.WaitGroup{} - wg.Add(10) - for i := 0; i < 10; i++ { - go func() { - defer wg.Done() - bm.Incr("edwardhey") - }() - } - wg.Wait() - if bm.Get("edwardhey").(int) != 10 { - t.Error("Incr err") - } -} - -func TestCache(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - - time.Sleep(30 * time.Second) - - if bm.IsExist("astaxie") { - t.Error("check err") - } - - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test GetMulti - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } - - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } -} - -func TestFileCache(t *testing.T) { - bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } - - //test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } - - os.RemoveAll("cache") -} diff --git a/pkg/adapter/flash.go b/pkg/adapter/flash.go index e5e1c187..02e75ed6 100644 --- a/pkg/adapter/flash.go +++ b/pkg/adapter/flash.go @@ -28,7 +28,7 @@ func NewFlash() *FlashData { // Set message to flash func (fd *FlashData) Set(key string, msg string, args ...interface{}) { - (*web.FlashData)(fd).Set(key, msg, args) + (*web.FlashData)(fd).Set(key, msg, args...) } // Success writes success message to flash. diff --git a/pkg/adapter/metric/prometheus_test.go b/pkg/adapter/metric/prometheus_test.go index d82a6dec..87286e02 100644 --- a/pkg/adapter/metric/prometheus_test.go +++ b/pkg/adapter/metric/prometheus_test.go @@ -22,7 +22,7 @@ import ( "github.com/prometheus/client_golang/prometheus" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg/adapter/context" ) func TestPrometheusMiddleWare(t *testing.T) { diff --git a/pkg/adapter/plugins/cors/cors_test.go b/pkg/adapter/plugins/cors/cors_test.go deleted file mode 100644 index 34039143..00000000 --- a/pkg/adapter/plugins/cors/cors_test.go +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cors - -import ( - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" -) - -// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header -type HTTPHeaderGuardRecorder struct { - *httptest.ResponseRecorder - savedHeaderMap http.Header -} - -// NewRecorder return HttpHeaderGuardRecorder -func NewRecorder() *HTTPHeaderGuardRecorder { - return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil} -} - -func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) { - gr.ResponseRecorder.WriteHeader(code) - gr.savedHeaderMap = gr.ResponseRecorder.Header() -} - -func (gr *HTTPHeaderGuardRecorder) Header() http.Header { - if gr.savedHeaderMap != nil { - // headers were written. clone so we don't get updates - clone := make(http.Header) - for k, v := range gr.savedHeaderMap { - clone[k] = v - } - return clone - } - return gr.ResponseRecorder.Header() -} - -func Test_AllowAll(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - if recorder.HeaderMap.Get(headerAllowOrigin) != "*" { - t.Errorf("Allow-Origin header should be *") - } -} - -func Test_AllowRegexMatch(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - origin := "https://bar.foo.com" - r, _ := http.NewRequest("PUT", "/foo", nil) - r.Header.Add("Origin", origin) - handler.ServeHTTP(recorder, r) - - headerValue := recorder.HeaderMap.Get(headerAllowOrigin) - if headerValue != origin { - t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue) - } -} - -func Test_AllowRegexNoMatch(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowOrigins: []string{"https://*.foo.com"}, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - origin := "https://ww.foo.com.evil.com" - r, _ := http.NewRequest("PUT", "/foo", nil) - r.Header.Add("Origin", origin) - handler.ServeHTTP(recorder, r) - - headerValue := recorder.HeaderMap.Get(headerAllowOrigin) - if headerValue != "" { - t.Errorf("Allow-Origin header should not exist, found %v", headerValue) - } -} - -func Test_OtherHeaders(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowCredentials: true, - AllowMethods: []string{"PATCH", "GET"}, - AllowHeaders: []string{"Origin", "X-whatever"}, - ExposeHeaders: []string{"Content-Length", "Hello"}, - MaxAge: 5 * time.Minute, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials) - methodsVal := recorder.HeaderMap.Get(headerAllowMethods) - headersVal := recorder.HeaderMap.Get(headerAllowHeaders) - exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders) - maxAgeVal := recorder.HeaderMap.Get(headerMaxAge) - - if credentialsVal != "true" { - t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal) - } - - if methodsVal != "PATCH,GET" { - t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal) - } - - if headersVal != "Origin,X-whatever" { - t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal) - } - - if exposedHeadersVal != "Content-Length,Hello" { - t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal) - } - - if maxAgeVal != "300" { - t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal) - } -} - -func Test_DefaultAllowHeaders(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - headersVal := recorder.HeaderMap.Get(headerAllowHeaders) - if headersVal != "Origin,Accept,Content-Type,Authorization" { - t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal) - } -} - -func Test_Preflight(t *testing.T) { - recorder := NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowMethods: []string{"PUT", "PATCH"}, - AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, - })) - - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(200) - }) - - r, _ := http.NewRequest("OPTIONS", "/foo", nil) - r.Header.Add(headerRequestMethod, "PUT") - r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive") - handler.ServeHTTP(recorder, r) - - headers := recorder.Header() - methodsVal := headers.Get(headerAllowMethods) - headersVal := headers.Get(headerAllowHeaders) - originVal := headers.Get(headerAllowOrigin) - - if methodsVal != "PUT,PATCH" { - t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal) - } - - if !strings.Contains(headersVal, "X-whatever") { - t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal) - } - - if !strings.Contains(headersVal, "x-casesensitive") { - t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal) - } - - if originVal != "*" { - t.Errorf("Allow-Origin is expected to be *, found %v", originVal) - } - - if recorder.Code != http.StatusOK { - t.Errorf("Status code is expected to be 200, found %d", recorder.Code) - } -} - -func Benchmark_WithoutCORS(b *testing.B) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - b.ResetTimer() - r, _ := http.NewRequest("PUT", "/foo", nil) - for i := 0; i < b.N; i++ { - handler.ServeHTTP(recorder, r) - } -} - -func Benchmark_WithCORS(b *testing.B) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowCredentials: true, - AllowMethods: []string{"PATCH", "GET"}, - AllowHeaders: []string{"Origin", "X-whatever"}, - MaxAge: 5 * time.Minute, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - b.ResetTimer() - r, _ := http.NewRequest("PUT", "/foo", nil) - for i := 0; i < b.N; i++ { - handler.ServeHTTP(recorder, r) - } -} diff --git a/pkg/adapter/router.go b/pkg/adapter/router.go index 5a36fbee..8e8d9fdb 100644 --- a/pkg/adapter/router.go +++ b/pkg/adapter/router.go @@ -249,6 +249,9 @@ func oldToNewFilterOpts(params []bool) []web.FilterOpt { opts := make([]web.FilterOpt, 0, 4) if len(params) > 0 { opts = append(opts, web.WithReturnOnOutput(params[0])) + } else { + // the default value should be true + opts = append(opts, web.WithReturnOnOutput(true)) } if len(params) > 1 { opts = append(opts, web.WithResetParams(params[1])) diff --git a/pkg/adapter/utils/file_test.go b/pkg/adapter/utils/file_test.go deleted file mode 100644 index b2644157..00000000 --- a/pkg/adapter/utils/file_test.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "path/filepath" - "reflect" - "testing" -) - -var noExistedFile = "/tmp/not_existed_file" - -func TestSelfPath(t *testing.T) { - path := SelfPath() - if path == "" { - t.Error("path cannot be empty") - } - t.Logf("SelfPath: %s", path) -} - -func TestSelfDir(t *testing.T) { - dir := SelfDir() - t.Logf("SelfDir: %s", dir) -} - -func TestFileExists(t *testing.T) { - if !FileExists("./file.go") { - t.Errorf("./file.go should exists, but it didn't") - } - - if FileExists(noExistedFile) { - t.Errorf("Weird, how could this file exists: %s", noExistedFile) - } -} - -func TestSearchFile(t *testing.T) { - path, err := SearchFile(filepath.Base(SelfPath()), SelfDir()) - if err != nil { - t.Error(err) - } - t.Log(path) - - _, err = SearchFile(noExistedFile, ".") - if err == nil { - t.Errorf("err shouldnt be nil, got path: %s", SelfDir()) - } -} - -func TestGrepFile(t *testing.T) { - _, err := GrepFile("", noExistedFile) - if err == nil { - t.Error("expect file-not-existed error, but got nothing") - } - - path := filepath.Join(".", "testdata", "grepe.test") - lines, err := GrepFile(`^\s*[^#]+`, path) - if err != nil { - t.Error(err) - } - if !reflect.DeepEqual(lines, []string{"hello", "world"}) { - t.Errorf("expect [hello world], but receive %v", lines) - } -} diff --git a/pkg/server/web/app.go b/pkg/server/web/app.go index ad3ff663..7511c7fe 100644 --- a/pkg/server/web/app.go +++ b/pkg/server/web/app.go @@ -492,15 +492,15 @@ func Handler(rootpath string, h http.Handler, options ...interface{}) *App { // The pos means action constant including // beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) -func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { - BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) +func InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) *App { + BeeApp.Handlers.InsertFilter(pattern, pos, filter, opts...) return BeeApp } // InsertFilterChain adds a FilterFunc built by filterChain. // This filter will be executed before all filters. // the filter's behavior is like stack -func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App { - BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...) +func InsertFilterChain(pattern string, filterChain FilterChain, opts ...FilterOpt) *App { + BeeApp.Handlers.InsertFilterChain(pattern, filterChain, opts...) return BeeApp } diff --git a/pkg/server/web/filter.go b/pkg/server/web/filter.go index e10faafc..9aab48d6 100644 --- a/pkg/server/web/filter.go +++ b/pkg/server/web/filter.go @@ -45,13 +45,14 @@ type FilterRouter struct { // 2. determining whether or not params need to be reset. func newFilterRouter(pattern string, filter FilterFunc, opts ...FilterOpt) *FilterRouter { mr := &FilterRouter{ - tree: NewTree(), - pattern: pattern, - filterFunc: filter, - returnOnOutput: true, + tree: NewTree(), + pattern: pattern, + filterFunc: filter, } - fos := &filterOpts{} + fos := &filterOpts{ + returnOnOutput: true, + } for _, o := range opts { o(fos) diff --git a/pkg/server/web/namespace.go b/pkg/server/web/namespace.go index e59f38c5..a792aa60 100644 --- a/pkg/server/web/namespace.go +++ b/pkg/server/web/namespace.go @@ -91,7 +91,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { a = FinishRouter } for _, f := range filter { - n.handlers.InsertFilter("*", a, f) + n.handlers.InsertFilter("*", a, f, WithReturnOnOutput(true)) } return n } diff --git a/pkg/server/web/router_test.go b/pkg/server/web/router_test.go index 14ad1484..33b75703 100644 --- a/pkg/server/web/router_test.go +++ b/pkg/server/web/router_test.go @@ -423,7 +423,7 @@ func TestInsertFilter(t *testing.T) { testName := "TestInsertFilter" mux := NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true)) if !mux.filters[BeforeRouter][0].returnOnOutput { t.Errorf( "%s: passing no variadic params should set returnOnOutput to true", @@ -436,7 +436,7 @@ func TestInsertFilter(t *testing.T) { } mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(false)) if mux.filters[BeforeRouter][0].returnOnOutput { t.Errorf( "%s: passing false as 1st variadic param should set returnOnOutput to false", @@ -444,7 +444,7 @@ func TestInsertFilter(t *testing.T) { } mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true), WithResetParams(true)) if !mux.filters[BeforeRouter][0].resetParams { t.Errorf( "%s: passing true as 2nd variadic param should set resetParams to true", @@ -461,7 +461,7 @@ func TestParamResetFilter(t *testing.T) { mux := NewControllerRegister() - mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true) + mux.InsertFilter("*", BeforeExec, beegoResetParams, WithReturnOnOutput(true), WithResetParams(true)) mux.Get(route, beegoHandleResetParams) @@ -514,8 +514,8 @@ func TestFilterBeforeExec(t *testing.T) { url := "/beforeExec" mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoBeforeExec1) + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, BeforeExec, beegoBeforeExec1, WithReturnOnOutput(true)) mux.Get(url, beegoFilterFunc) @@ -542,7 +542,7 @@ func TestFilterAfterExec(t *testing.T) { mux := NewControllerRegister() mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoAfterExec1, false) + mux.InsertFilter(url, AfterExec, beegoAfterExec1, WithReturnOnOutput(false)) mux.Get(url, beegoFilterFunc) @@ -570,10 +570,10 @@ func TestFilterFinishRouter(t *testing.T) { url := "/finishRouter" mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoFilterNoOutput) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1) + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, AfterExec, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(true)) mux.Get(url, beegoFilterFunc) @@ -604,7 +604,7 @@ func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { url := "/finishRouterMultiFirstOnly" mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false)) mux.InsertFilter(url, FinishRouter, beegoFinishRouter2) mux.Get(url, beegoFilterFunc) @@ -631,8 +631,8 @@ func TestFilterFinishRouterMulti(t *testing.T) { url := "/finishRouterMulti" mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false)) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, WithReturnOnOutput(false)) mux.Get(url, beegoFilterFunc) diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go index 9f73ce46..488729dc 100644 --- a/pkg/task/task_test.go +++ b/pkg/task/task_test.go @@ -15,6 +15,7 @@ package task import ( + "context" "errors" "fmt" "sync" @@ -25,7 +26,10 @@ import ( ) func TestParse(t *testing.T) { - tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) + tk := NewTask("taska", "0/30 * * * * *", func(ctx context.Context) error { + fmt.Println("hello world") + return nil + }) err := tk.Run(nil) if err != nil { t.Fatal(err) @@ -39,9 +43,9 @@ func TestParse(t *testing.T) { func TestSpec(t *testing.T) { wg := &sync.WaitGroup{} wg.Add(2) - tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) - tk2 := NewTask("tk2", "0,10,20 * * * * *", func() error { fmt.Println("tk2"); wg.Done(); return nil }) - tk3 := NewTask("tk3", "0 10 * * * *", func() error { fmt.Println("tk3"); wg.Done(); return nil }) + tk1 := NewTask("tk1", "0 12 * * * *", func(ctx context.Context) error { fmt.Println("tk1"); return nil }) + tk2 := NewTask("tk2", "0,10,20 * * * * *", func(ctx context.Context) error { fmt.Println("tk2"); wg.Done(); return nil }) + tk3 := NewTask("tk3", "0 10 * * * *", func(ctx context.Context) error { fmt.Println("tk3"); wg.Done(); return nil }) AddTask("tk1", tk1) AddTask("tk2", tk2) @@ -58,7 +62,7 @@ func TestSpec(t *testing.T) { func TestTask_Run(t *testing.T) { cnt := -1 - task := func() error { + task := func(ctx context.Context) error { cnt++ fmt.Printf("Hello, world! %d \n", cnt) return errors.New(fmt.Sprintf("Hello, world! %d", cnt))